Torch 2.1 compile + FSDP (mixed precision) + LlamaForCausalLM: RuntimeError: attempting to assign a gradient with dtype 'c10::BFloat16' to a tensor with dtype 'float'.
#111317
Comments
"""
A minimal reproduction of torch.compile, FSDP and torch 2.1 error.
RuntimeError: attempting to assign a gradient with dtype 'c10::BFloat16' to a tensor with dtype 'float'.
Please ensure that the gradient and the tensor have the same dtype
"""
import logging
import math
import os
import time
from pathlib import Path
import functools
from functools import partial
import math
from torch.optim.lr_scheduler import LambdaLR
import datasets
import torch
import transformers
import tokenizers
from tqdm.auto import tqdm
from transformers import AutoTokenizer
import accelerate
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from accelerate.utils.dataclasses import ProjectConfiguration
from accelerate import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import (MixedPrecision,
FullStateDictConfig,
FullOptimStateDictConfig,
ShardingStrategy,
BackwardPrefetch,
StateDictType
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
# from pretraining.losses import get_lm_loss_func
# from pretraining.scheduler import get_cosine_one_cycle_scheduler
# from pretraining.utils import (get_param_counts,
# enable_gradient_checkpointing,
# get_optim_param_groups)
from pretraining.arguments.pretraining_arguments import parse_args
from transformers import LlamaConfig, LlamaForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
# For speed testing purposes only.
from pretraining.benchmarking.dummy_data_utils import create_dummy_dataloaders
logger = get_logger(__name__, log_level="INFO")
def main():
args = parse_args()
mixed_precision_policy = MixedPrecision(param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.float32)
# Add embedding and lm head if needed.
llama_auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
LlamaDecoderLayer,
},
)
fsdp_plugin = FullyShardedDataParallelPlugin(
sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
auto_wrap_policy=llama_auto_wrap_policy,
mixed_precision_policy=mixed_precision_policy,
state_dict_config=FullStateDictConfig(offload_to_cpu=False),
optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=False),
backward_prefetch = BackwardPrefetch.BACKWARD_PRE,
state_dict_type=StateDictType.FULL_STATE_DICT,
forward_prefetch=False,
use_orig_params=True,
cpu_offload=False,
)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin,
log_with=args.report_to,
project_config=ProjectConfiguration(project_dir=None,
logging_dir=args.output_dir,
automatic_checkpoint_naming=False),
gradient_accumulation_steps=args.gradient_accumulation_steps)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
logger.info(f"FSDP Mixed Precision Policy: {accelerator.state.fsdp_plugin.mixed_precision_policy}")
logger.info(f"Native AMP is enabled: {accelerator.native_amp}")
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
if not torch.backends.cuda.matmul.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
logger.info("Setting torch.backends.cuda.matmul.allow_tf32 = True")
else:
logger.info("Already set: torch.backends.cuda.matmul.allow_tf32 = True")
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
if not torch.backends.cudnn.allow_tf32:
torch.backends.cudnn.allow_tf32 = True
logger.info("Setting torch.backends.cudnn.allow_tf32 = True")
else:
logger.info("Already set: torch.backends.cudnn.allow_tf32 = True")
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
# Create output and experiment directory if needed
task_prefix = 'clm' if not args.prefix_lm else 'plm'
experiment_name = (f"{task_prefix}_{args.model_size}_{args.optimizer}"
if args.experiment_name is None else args.experiment_name)
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
os.makedirs(os.path.join(args.output_dir, experiment_name), exist_ok=True)
accelerator.wait_for_everyone()
# Initialize config.
# elif args.model_size == "1B":
model_size_config = dict(hidden_size=2048,
num_hidden_layers=24,
num_attention_heads=16,
num_key_value_heads=16,
intermediate_size=4096)
# download model weights and config files.
config = LlamaConfig()
config.update(model_size_config)
logger.info(f"Model config: {config.to_json_string()}")
# Load tokenizer.
tokenizer = AutoTokenizer.from_pretrained(args.custom_tokenizer_path, use_fast=True)
# Update config with vocab size.
config.vocab_size = len(tokenizer.vocab)
# Initialize from pretrained LLaMa model.
model = LlamaForCausalLM(config)
# In case tokenizer has extra tokens.
prev_shape = model.model.embed_tokens.weight.size()
model.resize_token_embeddings(len(tokenizer))
logger.info(f"Resized word embeddings from: {prev_shape} to: {model.model.embed_tokens.weight.size()}")
assert model.model.embed_tokens.weight.size(1) % 8 == 0, f"embed_tokens must be divisible by 8."
assert model.model.embed_tokens.weight.size() == model.lm_head.weight.size(), \
(f"embed_tokens {model.model.embed_tokens.weight.size()} "
f"and lm_head {model.lm_head.weight.size()} shapes must be the same.")
# Create dataloaders.
(train_dl,valid_dl,train_ds,valid_ds) = create_dummy_dataloaders(args.dataset_name,
tokenizer,
args.block_size,
args.per_device_train_batch_size,
packed_inputs=False,
prefix_lm=False)
dataset_length = len(train_ds)
# Total batch size.
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info(f"Total batch size: {total_batch_size}")
# Initialize loss func. Using torch.nn.CrossEntropyLoss().
# loss_func = get_lm_loss_func(apex=False, flash_attn=False)
class LMLossTorch:
def __init__(self):
self.loss_fct = torch.nn.functional.cross_entropy
logger.info("Using PyTorch cross_entropy")
def compute(self, logits, labels, loss_mask=None, z_loss:float=None, ignore_index=-100):
# Ignore prediction for the last token.
shift_logits = logits[...,:-1,:].contiguous()
shift_labels = labels[...,1:].contiguous().long()
if loss_mask is not None:
loss_mask = loss_mask[...,:-1].contiguous().bool()
shift_labels.masked_fill_(~loss_mask, ignore_index)
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
shift_labels = shift_labels.view(-1)
# args: input, target, weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean', label_smoothing=0.0
loss = self.loss_fct(shift_logits, shift_labels, ignore_index=ignore_index)
if z_loss is not None:
log_z = torch.logsumexp(shift_logits[shift_labels!=ignore_index], dim=-1)**2
z_loss_val = z_loss*log_z.mean()
loss += z_loss_val
return loss, z_loss_val
return loss, None
loss_func = LMLossTorch()
# Calculate total number training steps.
num_update_steps_per_epoch = math.ceil(len(train_dl) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
else:
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
logger.info(f"num_update_steps_per_epoch (before prepare) : {num_update_steps_per_epoch}")
logger.info(f"max_train_steps (before prepare) : {args.max_train_steps}")
# Creates Dummy Scheduler if `scheduler` was specified in the config file else creates `args.lr_scheduler_type` Scheduler
total_scheduler_steps = int(args.max_train_steps*accelerator.num_processes)
if (args.num_warmup_steps is not None) and (args.num_warmup_fraction is not None):
raise ValueError("Only one of num_warmup_steps or num_warmup_fraction can be specified.")
if args.num_warmup_steps is not None:
scheduler_warmup_steps = args.num_warmup_steps
elif args.num_warmup_fraction is not None:
scheduler_warmup_steps = int(args.num_warmup_fraction * total_scheduler_steps)
# prepare model first in FSDP.
model = accelerator.prepare(model)
def get_optim_param_groups(model, weight_decay, no_decay):
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
decay_names = [n for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)]
no_decay_names = [n for n, p in model.named_parameters() if any(nd in n for nd in no_decay)]
return optimizer_grouped_parameters, no_decay_names, decay_names
optimizer_grouped_parameters, no_decay_names, decay_names = get_optim_param_groups(model, args.weight_decay,
no_decay=["embed", "bias", "norm.weight"])
newline = '\n'
logger.info(f"No decay: {newline.join(no_decay_names)}")
logger.info(f"Decay: {newline.join(decay_names)}")
# NOTE: Fused AdamW didn't work with FSDP.
optimizer = torch.optim.AdamW(optimizer_grouped_parameters,
lr=args.learning_rate,
betas=(0.9,0.95),
eps=1e-5)
def _get_cosine_one_cycle_lr_lambda(
current_step: int, *, num_warmup_steps: int, num_training_steps: int, min_lr_fraction = 0.1,
):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
scale_term = (1 - min_lr_fraction)
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
return (math.cos(math.pi * progress)+1) * 0.5 * scale_term + min_lr_fraction
def get_cosine_one_cycle_scheduler(optimizer, num_warmup_steps, num_training_steps, min_lr_fraction=0.1):
lr_lambda = partial(
_get_cosine_one_cycle_lr_lambda,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
min_lr_fraction=min_lr_fraction
)
return LambdaLR(optimizer, lr_lambda, last_epoch=-1)
lr_scheduler = get_cosine_one_cycle_scheduler(optimizer,
num_warmup_steps=scheduler_warmup_steps,
num_training_steps=total_scheduler_steps,
min_lr_fraction=args.min_lr_fraction)
# Prepare the remaining objects with our `accelerator`.
optimizer, train_dl, valid_dl, lr_scheduler = accelerator.prepare(
optimizer, train_dl, valid_dl, lr_scheduler
)
if isinstance(lr_scheduler, accelerate.scheduler.AcceleratedScheduler):
logger.info(f"Using AcceleratedScheduler")
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
# Based on dataloader preparation strategy, for IterableDataset batches will be dispatched from main process.
if args.total_dataset_tokens is None:
num_update_steps_per_epoch = math.ceil(len(train_dl) / args.gradient_accumulation_steps)
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
logger.info(f"num_update_steps_per_epoch (after prepare) : {num_update_steps_per_epoch}")
logger.info(f"max_train_steps (after prepare) : {args.max_train_steps}")
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
tracker_project_name = Path(args.output_dir).parent.name
accelerator.init_trackers(tracker_project_name, experiment_config)
accelerator.log({"model_config": config.to_json_string()})
# Train!
logger.info("***** Running training *****")
logger.info(f" Num examples = {dataset_length}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0
starting_epoch = 0
for epoch in range(starting_epoch, args.num_train_epochs):
model.train()
# Batches to be used for sophia hessian estimation.
total_loss = 0
step_total_loss = 0
step_total_zloss = 0
ema_loss = None
for step, batch in enumerate(train_dl):
# forward pass: `input_ids`, `decoder_causal_attention`, `decoder_segment_ids`
batch['input_ids'] = batch.pop(args.input_ids_key)
labels = batch["input_ids"].clone()
loss_mask = batch.pop('decoder_loss_mask', None)
_ = batch.pop('decoder_causal_attention', None)
# TODO: Check if this is faster than manual block diagonal attention bias.
if args.multipacked_inputs:
batch['input_ids'] = batch['input_ids'].view(-1).unsqueeze(0)
labels = labels.view(-1).unsqueeze(0)
loss_mask = loss_mask.view(-1).unsqueeze(0)
# TODO: Disable for now. May be affecting speed.
# batch['position_ids'] = batch['position_ids'].view(-1).unsqueeze(0)
offset = torch.cat([torch.tensor([0], device=batch['decoder_segment_ids'].device),
batch['decoder_segment_ids'][:-1,-1].cumsum(0)]).view(-1,1)
batch['attention_mask'] = (batch['decoder_segment_ids'] + offset).view(-1).unsqueeze(0)
outputs = model(**batch)
# compute loss: `logits`, `labels` `decoder_loss_mask`
loss, _ = loss_func.compute(outputs.logits, labels, loss_mask,
z_loss=args.z_loss, ignore_index=args.ignore_index)
# We keep track of the loss at each epoch
accelerator.backward(loss)
if (step + 1) % args.gradient_accumulation_steps == 0:
# No effect with deepspeed, can keep. Use deepspeed config to enable it.
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(model.parameters(), args.gradient_clipping)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(1)
completed_steps += 1
step_total_loss = step_total_loss / args.gradient_accumulation_steps
step_total_zloss = step_total_zloss / args.gradient_accumulation_steps
logger.info("Training complete!")
if __name__ == "__main__":
    main()


# Second script: create test dataloader.
"""
Prepare dummy tokenized dataset with 1% of data.
"""
from itertools import chain
import functools
from datasets import load_dataset
from transformers import default_data_collator
from torch.utils.data import DataLoader
import numpy as np
def tokenize_function(examples,tokenizer,text_column_name):
# no attention mask needed.
return tokenizer(examples[text_column_name], return_attention_mask=False, return_token_type_ids=False)
def group_texts(examples, block_size, packed_inputs, prefix_lm):
# Concatenate all texts.
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
total_length = len(concatenated_examples[list(examples.keys())[0]])
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
# customize this part to your needs.
if total_length >= block_size:
total_length = (total_length // block_size) * block_size
# Split by chunks of max_len.
result = {
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
if packed_inputs:
result["decoder_segment_ids"] = [[1]*len(l) for l in result["input_ids"]]
if prefix_lm:
result["decoder_causal_attention"] = [[0]*len(l) for l in result["input_ids"]]
return result
def create_dummy_dataloaders(dataset_name, tokenizer, sequence_length, batch_size, packed_inputs=True, prefix_lm=False):
raw_datasets = load_dataset(dataset_name)
raw_datasets["validation"] = load_dataset(
dataset_name,
split=f"train[:1%]",
)
# Use dummy dataset for benchmarking set to 1%.
raw_datasets["train"] = load_dataset(
dataset_name,
split=f"train[:1%]",
)
column_names = raw_datasets["train"].column_names
text_column_name = "text" if "text" in column_names else column_names[0]
tokenized_datasets = raw_datasets.map(
functools.partial(tokenize_function, tokenizer=tokenizer, text_column_name=text_column_name),
batched=True,
num_proc=None,
remove_columns=column_names,
load_from_cache_file=True,
desc="Running tokenizer on dataset",
)
lm_datasets = tokenized_datasets.map(
functools.partial(group_texts, block_size=sequence_length, packed_inputs=packed_inputs, prefix_lm=prefix_lm),
batched=True,
num_proc=None,
load_from_cache_file=True,
desc=f"Grouping texts in chunks of {sequence_length}",
)
lm_datasets = lm_datasets.rename_column("input_ids", "targets")
train_dataloader = DataLoader(
lm_datasets["train"], shuffle=False, collate_fn=default_data_collator, batch_size=batch_size
)
eval_dataloader = DataLoader(
lm_datasets["validation"], collate_fn=default_data_collator, batch_size=batch_size
)
    return train_dataloader, eval_dataloader, lm_datasets["train"], lm_datasets["validation"]
RuntimeError: attempting to assign a gradient with dtype 'c10::BFloat16' to a tensor with dtype 'float'.
@wconstab, assigned to you temporarily; feel free to assign to the right owner.
Previous issue, possibly due to FSDP + mixed precision + dynamo + autograd (assigning a gradient of incorrect dtype): #110797
Actually, couldn't this be the root cause of it all? #111794: not handling the autocast context manager setup/unwind appropriately when using DDP. Not sure if this bleeds into FSDP. Could you confirm? EDIT: turns out the answer is no, but does FSDP introduce its own graph breaks?
I still see autocast in your stack trace. Though you claim to have disabled it, I think some of the model code has been decorated with autocast.
The fx graph splitter pass is not used at all for FSDP. FSDP's graph breaks happen more implicitly: Dynamo sees FSDP's Python code and is configured to graph-break on it. In the DDP case, there is not really any DDP code running during forward, and the reason to insert the graph breaks only becomes apparent when you consider the comm operations that happen during backward. So we had to take another approach to add the graph breaks, using the fx pass.
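For anyone who wants to see those implicit breaks, a minimal sketch using torch._dynamo.explain (assuming the torch 2.1-style API; the toy module is illustrative and field names may vary by version):

import torch
import torch._dynamo as dynamo

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.lin(x).relu()

model = Toy()  # wrap in FSDP here to see FSDP-induced graph breaks
explanation = dynamo.explain(model)(torch.randn(2, 8))
print(explanation.graph_break_count)      # how many breaks Dynamo inserted
for reason in explanation.break_reasons:  # why each break happened
    print(reason)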
Sorry, I only included one of the stack traces; in this case it happens to be the autocasted run, but the same error happened without it as well.
@KeremTurgutlu, may I ask what your settings are for the supposedly non-autocasted run, for these?

mixed_precision_policy = MixedPrecision(param_dtype=torch.bfloat16,
                                        reduce_dtype=torch.bfloat16,
                                        buffer_dtype=torch.float32)
logger.info(f"FSDP Mixed Precision Policy: {accelerator.state.fsdp_plugin.mixed_precision_policy}")
logger.info(f"Native AMP is enabled: {accelerator.native_amp}")

Could you share the logs so I can confirm?
@wconstab, are you still working on this?
No, I never got a chance to look into this. Unassigning for now.
Seeing the same issue when using the combination of FSDP + compile + MP. With amp autocast I don't see these issues, but whenever I add the
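A minimal sketch of the autocast-only setup this comment reports as working, assuming a standard FSDP wrap; model, optimizer, and train_dl are placeholders, not the commenter's actual code:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

fsdp_model = FSDP(model, use_orig_params=True)  # no MixedPrecision policy: params stay fp32
fsdp_model = torch.compile(fsdp_model)

for batch in train_dl:
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = fsdp_model(**batch).loss         # forward runs in bf16 under autocast (assumes labels in batch)
    loss.backward()                             # grads come back as fp32, matching the fp32 params
    optimizer.step()
    optimizer.zero_grad()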
Hi @jon-chuang, I can confirm that even with the settings you gave ('param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32'), the same error occurs. Please see the logs below:
Additional info: this issue disappears on PyTorch 2.0.0 and PyTorch 2.1.2.
Hit a similar error when using FSDP and compile together. If I set the mixed precision dtype to fp32, the error turns into a tensor size mismatch. It seems the gradient (FlatTensor) in FSDP mode is not handled well by compile. My torch version is 2.2, and the way I use FSDP is like this:
@tangjiasheng did you manage to find a solution for the size mismatch? Running into the same error, and this is the only post referencing it I could find. |
Sorry, the answer is NO... |
We are also hitting this issue with SHARD_GRAD_OP. Specifically, we hit it on the first eval step, after a training step has run, while training Stable Diffusion 2. If we never do evaluation/saving, the model seems to run fine. Hopefully this helps narrow it down a bit, @yf225. I can confirm that this issue is present both on 2.2 and on nightly as of last week.
My current hypothesis is that these errors happen due to OOM, but no proper OOM error is raised, only the ones discussed in this thread.
Also hit the issue for both bf16 and fp16 training.
Also fail with |
I encounter this issue when I run an evaluation step after a training step. If I do evaluation before any training steps, torch.compile works fine and training runs until the first evaluation step after the torch.compile backward passes. Seems like this has to do with it not respecting
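A sketch of the ordering these comments describe (illustrative names; a compiled FSDP model and the usual optimizer setup are assumed):

# Training step first ...
model.train()
loss = compiled_fsdp_model(**train_batch).loss
loss.backward()
optimizer.step()
optimizer.zero_grad()

# ... then the first eval step afterwards is where the RuntimeError reportedly surfaces;
# running eval before any training step is reported to work fine.
model.eval()
with torch.no_grad():
    eval_out = compiled_fsdp_model(**eval_batch)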
This is the case for me. Reducing the batch size worked.
Also hitting this issue with torch version 2.3.0a0+6ddf5cf85e.nv24.04. As Skylion007 mentioned, this issue occurs for me if I train before eval, but it goes away if I do eval before train.
Any updates? Running into
#134614 might be able to fix this issue.
#134614 should fix this issue. Please reopen the issue if it's not the case.
🐛 Describe the bug
I am getting the following error when training LlamaForCausalLM with torch 2.1, FSDP (mixed precision), and torch.compile. The exact same code works when torch.compile is disabled or when torch 2.0.1 is used. I also tried enabling and disabling amp autocast; it doesn't matter, and the same error happens.
I am using a Docker image; the error happens in Environment 2, which is provided in the Versions section.
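For quick reference, a condensed sketch of the failing combination (not the full repro script above); model and batch are placeholders and distributed init is assumed:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

policy = MixedPrecision(param_dtype=torch.bfloat16,
                        reduce_dtype=torch.bfloat16,
                        buffer_dtype=torch.float32)
model = FSDP(model, mixed_precision=policy, use_orig_params=True)
model = torch.compile(model)     # fails on torch 2.1; fine without compile or on torch 2.0.1

loss = model(**batch).loss       # assumes labels are in the batch
loss.backward()                  # RuntimeError: attempting to assign a gradient with dtype
                                 # 'c10::BFloat16' to a tensor with dtype 'float'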
Error logs
Minified repro
No response
Versions
Environment 1
Environment 2
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519 @wanchaol @fduwjj @wz337 @kiukchung @d4l3k @LucasLLC @tianyu-l @gchanan @kadeng