Add fp8 doc and example #940

Merged
merged 14 commits into from
Jan 22, 2024
8 changes: 4 additions & 4 deletions atorch/README.md
@@ -61,7 +61,7 @@ ATorch is an extension library of PyTorch developed by Ant Group's AI Infrastruc
## Installation

ATorch supports PyTorch version >= 1.12; version 2.1 or above is preferred.
For example, you can use docker image <code>easydl/atorch:iml_pt210</code> (or aliyun mirror <code>registry.cn-hangzhou.aliyuncs.com/atorch/atorch:iml_pt210</code>) which has PyTorch 2.1 installed.
For example, you can use the docker image <code>registry.cn-hangzhou.aliyuncs.com/atorch/atorch:pt210_te</code>, which has PyTorch 2.1 installed.

### Install From PyPI
Install atorch in any PyTorch-preinstalled environment (such as a container created with the docker image above) with <code>pip</code>:
@@ -76,9 +76,9 @@ pip install atorch
# clone repository
git clone https://github.com/intelligent-machine-learning/dlrover.git
cd dlrover/atorch
# build package
sh dev/scripts/build.sh
# install the created package in dist directory
# build package, optionally set a version
sh dev/scripts/build.sh [version]
# install the created package in the dist directory (note that if a version is set, the file name differs)
pip install dist/atorch-0.1.0.dev0-py3-none-any.whl
```

8 changes: 7 additions & 1 deletion atorch/atorch/data/__init__.py
@@ -1,6 +1,12 @@
from .coworker_dataset import build_coworker_dataloader
from .data_utils import expand_batch_dim, get_sample_batch
from .elastic_dataloader import build_coworker_dataloader_with_elasticdl, get_elastic_dataloader

try:
    from .elastic_dataloader import build_coworker_dataloader_with_elasticdl, get_elastic_dataloader
except TypeError:
    print("protobuf version mismatch. elastic_dataloader cannot be used")
    build_coworker_dataloader_with_elasticdl = None
    get_elastic_dataloader = None
from .preloader import GpuPreLoader, data_to_device
from .shm_context import ShmData, create_coworker_shm_context
from .shm_dataloader import ShmDataloader, create_shm_dataloader
5 changes: 3 additions & 2 deletions atorch/dev/docker/Dockerfile-ubuntu2004-pt210
@@ -59,7 +59,7 @@ RUN apt-get install -y ibverbs-utils rdma-core && \
##############################################################################
FROM torch-base as atorch-fa-base
ENV USE_NCCL=1
ARG TORCH_CUDA_ARCH_LIST="6.0 7.0 7.5 8.0 8.6+PTX"
ARG TORCH_CUDA_ARCH_LIST="6.0 7.0 7.5 8.0 8.6 8.9 9.0+PTX"
# RUN yum install libnccl-2.16.2-1+cuda11.0 libnccl-devel-2.16.2-1+cuda11.0 -y && \
RUN pip install https://dlrover.oss-cn-beijing.aliyuncs.com/atorch/libs/fastmoe.tar.gz
RUN pip install dm-tree setuptools packaging && \
@@ -72,7 +72,8 @@ RUN pip install dm-tree setuptools packaging && \
cd .. && rm -rf apex*
RUN git clone --branch stable --recursive https://github.com/NVIDIA/TransformerEngine.git && \
cd TransformerEngine && \
CUDACXX=/usr/local/cuda/bin/nvcc NVTE_FRAMEWORK=pytorch pip install . && \
CUDACXX=/usr/local/cuda/bin/nvcc NVTE_FRAMEWORK=pytorch python setup.py bdist_wheel && \
pip install dist/transformer_engine*.whl && \
cd .. && rm -rf TransformerEngine

# make sure to rebuild fa if py/cuda/... versions are updated.
71 changes: 70 additions & 1 deletion atorch/docs/auto_accelerate_api.md
@@ -148,7 +148,7 @@ Not None, semi-automatic mode. Supported formats:


<tr>
<td>ignore_dryrun_on_load_strategy (optional, default False)</td>
<td>ignore_dryrun_on_load_strategy (optional, default True)</td>
<td>
If True, ignore dryrun when load_strategy is not None.
</td>
@@ -263,6 +263,56 @@ For bfloat16, it does not check if the gradients are infinite by default. If you

Training in half precision. The default configuration is <code>"fp16"</code>. To use bfloat16, set the config to <code>"bf16"</code>.
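
For example, the corresponding load_strategy entries could look like this (a minimal sketch; the entry format mirrors the examples used elsewhere in this repo):

```
strategy = [("half", "bf16")]  # train in bfloat16
# strategy = ["half"]          # string form uses the default fp16 config
```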

### fp8

Use the FP8 capability provided by [transformer_engine](https://github.com/NVIDIA/TransformerEngine) (te) to accelerate computation. This optimization method automatically replaces <code>nn.Linear</code> modules in the model with <code>te.Linear</code> to speed up computation. fp8 is compatible with other optimization methods such as [amp_native](#amp_native), [half](#half), [fsdp](#fsdp), [checkpoint](#checkpoint), etc.
Note that fp8 training with LoRA ([peft](https://github.com/huggingface/peft)) is not supported yet.

**Pre-requisites**
- Hardware support: GPU with sm >= 8.9 (such as Ada, Hopper). If not satisfied, the fp8 optimization is ignored.
- Software support: transformer_engine installed, version >= 1.0.
- Tensor dimension requirements: for tensor core fp8 computation, a tensor's dim[0] must be a multiple of 8 and dim[1] a multiple of 16. Because the backward pass of <code>nn.Linear</code> requires a transpose op, both the weight of <code>nn.Linear</code> and the module's input need dim[0] and dim[1] to be multiples of 16. The fp8 optimization method checks weight dimensions automatically; it is up to the user to ensure that the input to <code>nn.Linear</code> also meets this requirement (a quick check is sketched below).
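
As an illustration of the dimension rule, the small helper below checks a linear layer and a batch dimension (this helper is a sketch for this doc, not part of atorch's API):

```
from torch import nn


def fp8_dims_ok(linear: nn.Linear, batch_dim: int) -> bool:
    # weight dims and the input's leading dim should all be multiples of 16
    out_features, in_features = linear.weight.shape
    return all(d % 16 == 0 for d in (out_features, in_features, batch_dim))


print(fp8_dims_ok(nn.Linear(1024, 4096), batch_dim=32))  # True
print(fp8_dims_ok(nn.Linear(1000, 4096), batch_dim=32))  # False: 1000 is not a multiple of 16
```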

**Supported config parameters**

```
include: List[str], default None.
    If None, all nn.Linear modules are candidates for te replacement.
    If not None, an nn.Linear module is a candidate only if its name contains at least one of the strings in include.
exclude: List[str], default None.
    If None, all modules that pass the include test use te.
    If not None, an nn.Linear module whose name contains any of the strings in exclude will not use te.
verbose: Bool, default False. If True, print the names of the submodules replaced by te.Linear.
recipe.DelayedScaling parameters:
    margin: default 0
    interval: default 1
    fp8_format: "HYBRID" (default) or "E4M3"
    amax_history_len: default 1024
    amax_compute_algo: "max" (default) or "most_recent"
    reduce_amax: default True
```

**Default config**
```
{"include": None, "exclude": None, "margin": 0, "interval": 1, "fp8_format": "HYBRID", "amax_history_len": 1024, "amax_compute_algo": "max", "reduce_amax": True}
```

All <code>nn.Linear</code> instances that pass the "include" and "exclude" conditions and whose weight dim[0] and dim[1] are multiples of 16 are automatically converted to <code>te.Linear</code>; fp8 computation then uses a <code>recipe.DelayedScaling</code> built from the config parameters other than "include" and "exclude".
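
For example, a user config that overrides only a few fields might look like this (illustrative values; unspecified keys keep the defaults above):

```
fp8_config = {
    "include": ("layers",),  # only nn.Linear modules whose name contains "layers"
    "fp8_format": "E4M3",    # override the default "HYBRID" format
    "amax_history_len": 16,  # shorter amax history than the default 1024
    "verbose": True,         # print the names of replaced modules
}
# used as a load_strategy item: ("fp8", fp8_config)
```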

**Example**

In a [llama](https://github.com/huggingface/transformers/blob/53cffeb33c2f46ccef8969da076aa620f70801c4/src/transformers/models/llama/modeling_llama.py#L1106) model, <code>nn.Linear</code> exists not only in <code>LlamaDecoderLayer</code> but also in <code>lm_head</code>. Using fp8 training for the <code>nn.Linear</code> modules in <code>LlamaDecoderLayer</code> usually does not affect convergence accuracy, but convergence degrades severely when <code>lm_head</code> also uses fp8. In this case, you can set the config so that module replacement only affects <code>LlamaDecoderLayer</code> and not <code>lm_head</code>.

This can be achieved using <code>include</code> config parameter:

<code>config = {"include": ("layers",)}</code>

Or using <code>exclude</code> config parameter:

<code>config = {"exclude": ("lm_head",)}</code>



### module_replace

Automatic module optimization, which replaces optimizable modules with optimized implementations.
@@ -320,6 +370,25 @@ config = {"forward_prefetch": True, "limit_all_gathers": True, "sync_module_stat

Add <code>{"use_orig_params": True}</code> if multiple parameter groups with different hyperparamters are used in optimizer. Try add <code>{"fsdp_wrap_params_outmost": True}</code> for LORA finetuning to see if any performance improvement.

### checkpoint

Activation checkpointing is a memory-saving method that trades computation for memory. It does not keep activations during the forward pass, but recomputes them in the backward pass for gradient computation. A configuration is required to indicate which modules should be checkpointed.

Configuration can be a tuple of module types or module names, such as:
```
config = (GPT2Attention, GPT2MLP)
```

PyTorch provides two checkpoint implementations, no_reentrant and reentrant. no_reentrant is the default and performs better than reentrant. In some cases, such as when the model definition contains <code>@torch.jit.script</code>, the no_reentrant implementation may fail and reentrant should be used. The checkpoint configuration supports a dict format for choosing the reentrant implementation.
```
config = {
"wrap_class": (GPT2Attention, GPT2MLP), # modules to checkpoint
"no_reentrant": False, # use reentrant implementation
}
```
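
As in this repo's examples, the checkpoint config is passed as a load_strategy item; here GPT2Attention/GPT2MLP from transformers stand in for your own module types:

```
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention

strategy = [
    ("checkpoint", {"wrap_class": (GPT2Attention, GPT2MLP), "no_reentrant": False}),
]
```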



### tensor_parallel

Tensor parallelism, which splits modules automatically in Megatron-style tensor parallel fashion. The degree of tensor parallelism is specified in the parallel_mode configuration, such as <code>("tensor", 8)</code> for degree = 8.
5 changes: 5 additions & 0 deletions atorch/examples/auto_accelerate/README.md
@@ -69,3 +69,8 @@ Train llama model in semi-automatic mode, using (fsdp, amp_native, module_replac
```
python -m atorch.distributed.run --nproc_per_node 8 train.py --model_type llama --distributed --hidden_size 64 --head_num 4 --layer_num 4 --seq_length 32 --load_strategy --use_fsdp --use_amp --use_module_replace --use_checkpointing --user_created_dataloader
```
To also use fp8 for training on GPUs with fp8 capability (sm >= 8.9):
```
python -m atorch.distributed.run --nproc_per_node 8 train.py --model_type llama --distributed --hidden_size 64 --head_num 4 --layer_num 4 --seq_length 32 --load_strategy --use_fsdp --use_amp --use_module_replace --use_checkpointing --use_fp8 --user_created_dataloader
```

9 changes: 5 additions & 4 deletions atorch/examples/auto_accelerate/data.py
@@ -7,17 +7,18 @@


class ToyDataset(Dataset):
    def __init__(self, size, data_size=(16,), output_size=(4,)):
    def __init__(self, size, data_size=(16,), input_size=(16,), output_size=(4,)):
        self.size = size
        self.data_size = data_size
        self.input_size = input_size
        self.output_size = output_size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return {
            "input": np.ones(self.data_size, dtype=np.float32) * idx,
            "input": np.ones(self.input_size, dtype=np.float32) * idx,
            "label": np.ones(self.output_size, dtype=np.float32),
        }

@@ -63,9 +64,9 @@ def __len__(self):
        return self.total_samples


def get_dataset(model_type, seq_length=128, datasize=1000):
def get_dataset(model_type, seq_length=128, input_size=16, output_size=8, datasize=1000):
    if model_type == ModelType.TOY:
        return ToyDataset(size=datasize)
        return ToyDataset(size=datasize, input_size=input_size, output_size=output_size)
    if model_type == ModelType.GPT2 or model_type == ModelType.LLAMA:
        vocab_size = get_vocab_size(model_type)
        return BenchmarkLMDataset(vocab_size=vocab_size, max_source_positions=seq_length, total_samples=datasize)
4 changes: 3 additions & 1 deletion atorch/examples/auto_accelerate/modeling.py
@@ -108,7 +108,9 @@ def get_model(model_type, config):
    # config: dict with hidden_size, head_num, layer_num, seq_length for llms

    if model_type == ModelType.TOY:
        model = ToyModel()
        model = ToyModel(
            in_features=config["in_features"], out_features=config["out_features"], num_linears=config["num_linears"]
        )
        return model

    # llms
48 changes: 40 additions & 8 deletions atorch/examples/auto_accelerate/train.py
@@ -38,13 +38,16 @@ def parse_args():
parser.add_argument("--layer_num", type=int, default=3, required=False)
parser.add_argument("--seq_length", type=int, default=16, required=False)
parser.add_argument("--batchsize", type=int, default=8, required=False)
parser.add_argument("--in_size", type=int, default=16, required=False)
parser.add_argument("--out_size", type=int, default=8, required=False)
parser.add_argument("--distributed", default=False, action="store_true")
parser.add_argument("--user_created_dataloader", default=False, action="store_true")
parser.add_argument("--load_strategy", default=False, action="store_true")
parser.add_argument("--optim_grouped_params", default=False, action="store_true")
parser.add_argument("--log_interval", type=int, default=10, required=False)
parser.add_argument("--use_fsdp", default=False, action="store_true")
parser.add_argument("--use_amp", default=False, action="store_true")
parser.add_argument("--use_fp8", default=False, action="store_true")
parser.add_argument("--use_checkpointing", default=False, action="store_true")
parser.add_argument("--use_module_replace", default=False, action="store_true")

@@ -67,18 +70,31 @@ def train(args):

device = "cuda" if torch.cuda.is_available() else "cpu"

# get model, loss_func,
model_config = {
"hiddien_size": args.hiddien_size,
"head_num": args.head_num,
"layer_num": args.layer_num,
"seq_length": args.seq_length,
}
# get model, loss_func
if model_type == ModelType.TOY:
model_config = {
"in_features": args.in_size,
"out_features": args.out_size,
"num_linears": args.layer_num,
}
else:
model_config = {
"hiddien_size": args.hiddien_size,
"head_num": args.head_num,
"layer_num": args.layer_num,
"seq_length": args.seq_length,
}
model = get_model(model_type, model_config)
print("Get model with class ", model.__class__)
loss_func = get_loss_func(model_type)

dataset = get_dataset(model_type, seq_length=args.seq_length, datasize=args.datasize)
dataset = get_dataset(
model_type,
seq_length=args.seq_length,
input_size=args.in_size,
output_size=args.out_size,
datasize=args.datasize,
)
dataloader_args = get_dataloader_args(model_type, batch_size=args.batchsize)

strategy = None
@@ -108,6 +124,22 @@
    if args.use_checkpointing:
        checkpoint_modules = (get_module_type(model_type),)
        strategy.append(("checkpoint", checkpoint_modules))
    # fp8
    if args.use_fp8:
        if model_type == ModelType.LLAMA:
            strategy.append(("fp8", {"include": ("layers",)}))
        elif model_type == ModelType.TOY:
            if args.in_size % 16 != 0 or args.out_size % 16 != 0 or args.batchsize % 16 != 0:
                print(
                    "fp8 is ignored. To use fp8 for toy model, "
                    + "in_size({}), out_size({}) and batchsize({}) must be multiples of 16!".format(
                        args.in_size, args.out_size, args.batchsize
                    )
                )
            else:
                strategy.append("fp8")
        else:
            print("fp8 is ignored for gpt2 model")

    # optimizer
    if model_type == ModelType.LLAMA:
7 changes: 6 additions & 1 deletion atorch/examples/llama2/README.md
@@ -10,7 +10,7 @@ This document presents examples of using ATorch to pretrain or finetune the Hugg

## FSDP

Fully Sharded Data Parallel (FSDP) is PyTorch's implementation of ZeRO3. This example uses FSDP for distributed training, and can be used with other training optimizations, such as mixed precision, gradient checkpointing, etc. This is implemented by calling auto_accelerate API with load_strategy argument, and load_strategy specifies the training optimization method combination.
Fully Sharded Data Parallel (FSDP) is PyTorch's implementation of ZeRO3. This example uses FSDP for distributed training, and it can be combined with other training optimizations such as mixed precision (fp16/bf16/fp8), gradient checkpointing, etc. This is implemented by calling the auto_accelerate API with the load_strategy argument, which specifies the combination of training optimization methods.

### Scripts

@@ -25,10 +25,15 @@ pip install -r requirements.txt
# Configurable environment variable: DATASET_PATH, MODEL_NAME_OR_PATH, PER_DEVICE_TRAIN_BATCH_SIZE, etc.
sh fsdp_llama2_entry.sh

# use fp8
USE_FP8=1 sh fsdp_llama2_entry.sh

# use lora
USE_LORA=1 sh fsdp_llama2_entry.sh
```

Note that transformer_engine is required for fp8. You can use the docker image <code>registry.cn-hangzhou.aliyuncs.com/atorch/atorch:pt210_te</code>, which has transformer_engine pre-installed.

## DS 3D Parallel
### Intro
- For large-scale model training (at the 100B+ parameter level), besides FSDP/ZeRO3 parallelism, 3D parallelism is widely used in the deep learning community. 3D parallelism includes tensor parallel, pipeline parallel, and data parallel. Megatron-LM and DeepSpeed provide excellent 3D parallelism implementations that are popular among users.
10 changes: 10 additions & 0 deletions atorch/examples/llama2/fsdp_llama2.py
@@ -103,6 +103,11 @@ def parse_args():
action="store_true",
help="Use gradient checkpointing or not.",
)
parser.add_argument(
"--fp8",
action="store_true",
help="Use fp8 or not.",
)
args = parser.parse_args()

return args
@@ -200,6 +205,11 @@ def make_inputs_require_grad(module, input, output):
strategy.append(("half", "bf16"))
if args.gradient_checkpointing:
strategy.append(("checkpoint", (LlamaDecoderLayer,)))
if args.fp8:
if args.peft_type is not None:
logger.warning("fp8 ignored as fp8 for lora training is not implemented yet.")
else:
strategy.append(("fp8", {"include": ("layers",)}))
status, result, best_strategy = auto_accelerate(
model,
torch.optim.AdamW,
8 changes: 7 additions & 1 deletion atorch/examples/llama2/fsdp_llama2_entry.sh
@@ -20,6 +20,12 @@ else
"
fi

if [ -z "$USE_FP8" ]; then
FP8_OPT=""
else
FP8_OPT="--fp8"
fi

python -m atorch.distributed.run \
--nnodes="$WORLD_SIZE" \
--nproc_per_node="$NUM_GPUS_PER_NODE" \
@@ -30,4 +36,4 @@ python -m atorch.distributed.run \
--per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \
--precision bf16_amp \
--gradient_checkpointing \
$LORA_OPT
$LORA_OPT $FP8_OPT
4 changes: 3 additions & 1 deletion atorch/examples/llama2/requirements.txt
@@ -1,2 +1,4 @@
datasets>=2.14.6
peft==0.4.0
peft==0.4.0
modelscope
atorch>=0.1.7