Torch 2.1 compile + FSDP (mixed precision) + LlamaForCausalLM: RuntimeError: attempting to assign a gradient with dtype 'c10::BFloat16' to a tensor with dtype 'float'. #111317

Closed
KeremTurgutlu opened this issue Oct 15, 2023 · 26 comments
Labels
module: fsdp · oncall: distributed (Add this issue/PR to distributed oncall triage queue) · oncall: pt2 · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

KeremTurgutlu commented Oct 15, 2023

🐛 Describe the bug

I am getting the following error when training LlamaForCausalLM with torch 2.1, FSDP (mixed precision), and torch.compile. The exact same code works when torch.compile is disabled or when torch 2.0.1 is used. I also tried enabling and disabling AMP autocast; it makes no difference, and the same error occurs.

RuntimeError: attempting to assign a gradient with dtype 'c10::BFloat16' to a tensor with dtype 'float'.
Please ensure that the gradient and the tensor have the same dtype
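
For reference, the failing statement in the traceback below (flat_param.grad = flat_param_grad) is a plain assignment to Tensor.grad, so the message is the generic dtype check in the .grad setter; it can be hit outside FSDP as well (toy illustration, not part of the original report):

import torch

# fp32 leaf parameter, bf16 gradient: the .grad setter enforces matching dtypes
# and should raise the same dtype-mismatch RuntimeError as above.
param = torch.zeros(4, dtype=torch.float32, requires_grad=True)
param.grad = torch.zeros(4, dtype=torch.bfloat16)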

I am using a Docker image; the error happens in Environment 2, which is provided in the Versions section below.
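
For context, the failing combination is essentially LlamaForCausalLM wrapped by FSDP with a bf16 MixedPrecision policy and use_orig_params=True, with torch.compile applied on top. The sketch below condenses that setup (it is not the full script and not a verified minified repro; the tiny model config, dummy batch, and run() helper are made up for illustration, and it would be launched with torchrun):

import functools
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (FullyShardedDataParallel as FSDP,
                                    MixedPrecision, ShardingStrategy)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import LlamaConfig, LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer


def run():
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    # Tiny config so the sketch is cheap to run; the real script uses a ~1B config.
    model = LlamaForCausalLM(LlamaConfig(hidden_size=256, num_hidden_layers=2,
                                         num_attention_heads=4,
                                         intermediate_size=512)).cuda()
    model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
        auto_wrap_policy=functools.partial(transformer_auto_wrap_policy,
                                           transformer_layer_cls={LlamaDecoderLayer}),
        mixed_precision=MixedPrecision(param_dtype=torch.bfloat16,
                                       reduce_dtype=torch.bfloat16,
                                       buffer_dtype=torch.float32),
        use_orig_params=True,
    )
    model = torch.compile(model)  # per the report, the error goes away without torch.compile

    input_ids = torch.randint(0, 1000, (1, 16), device="cuda")
    for _ in range(3):  # a few forward/backward passes, as in the training loop below
        out = model(input_ids=input_ids, labels=input_ids)
        out.loss.backward()


if __name__ == "__main__":
    run()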

Error logs

  0%|          | 0/5 [00:00<?, ?it/s]
(identical traceback on every failing rank)
Traceback (most recent call last):
  File "/workspace/workdir/models/pretraining/huggingface/llama/pretrain_fsdp_torch2.1_minimal.py", line 511, in <module>
    main()
  File "/workspace/workdir/models/pretraining/huggingface/llama/pretrain_fsdp_torch2.1_minimal.py", line 387, in main
    outputs = model(**batch)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 839, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/accelerate/src/accelerate/utils/operations.py", line 659, in forward
    return model_forward(*args, **kwargs)
  File "/workspace/accelerate/src/accelerate/utils/operations.py", line 647, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/workspace/transformers/src/transformers/models/llama/modeling_llama.py", line 1034, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/transformers/src/transformers/models/llama/modeling_llama.py", line 921, in forward
    layer_outputs = decoder_layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 825, in forward
    args, kwargs = _pre_forward(
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_runtime_utils.py", line 429, in _pre_forward
    unshard_fn(state, handle)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_runtime_utils.py", line 464, in _pre_forward_unshard
    _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_runtime_utils.py", line 336, in _unshard
    ran_pre_unshard = handle.pre_unshard()
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/flat_param.py", line 1194, in pre_unshard
    ret = self._writeback_orig_params()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/flat_param.py", line 2202, in _writeback_orig_params
    flat_param.grad = flat_param_grad
RuntimeError: attempting to assign a gradient with dtype 'c10::BFloat16' to a tensor with dtype 'float'. Please ensure that the gradient and the tensor have the same dtype
============================================================
pretrain_fsdp_torch2.1_minimal.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2023-10-14_20:44:01
  host      : ec3b2a9a542c
  rank      : 2 (local_rank: 2)
  exitcode  : 1 (pid: 348906)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[2]:
  time      : 2023-10-14_20:44:01
  host      : ec3b2a9a542c
  rank      : 3 (local_rank: 3)
  exitcode  : 1 (pid: 348907)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[3]:
  time      : 2023-10-14_20:44:01
  host      : ec3b2a9a542c
  rank      : 4 (local_rank: 4)
  exitcode  : 1 (pid: 348908)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[4]:
  time      : 2023-10-14_20:44:01
  host      : ec3b2a9a542c
  rank      : 5 (local_rank: 5)
  exitcode  : 1 (pid: 348909)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[5]:
  time      : 2023-10-14_20:44:01
  host      : ec3b2a9a542c
  rank      : 6 (local_rank: 6)
  exitcode  : 1 (pid: 348910)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[6]:
  time      : 2023-10-14_20:44:01
  host      : ec3b2a9a542c
  rank      : 7 (local_rank: 7)
  exitcode  : 1 (pid: 348911)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-10-14_20:44:01
  host      : ec3b2a9a542c
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 348905)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================

Minified repro

No response

Versions

Environment 1

PyTorch version: 2.0.1+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.27.4
Libc version: glibc-2.35

Python version: 3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-86-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.2.128
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB

Nvidia driver version: 535.104.12
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      48 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             224
On-line CPU(s) list:                0-223
Vendor ID:                          AuthenticAMD
Model name:                         AMD EPYC 7713 64-Core Processor
CPU family:                         25
Model:                              1
Thread(s) per core:                 1
Core(s) per socket:                 224
Socket(s):                          1
Stepping:                           1
BogoMIPS:                           3999.99
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw perfctr_core ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves clzero xsaveerptr wbnoinvd arat npt nrip_save umip vaes vpclmulqdq rdpid fsrm arch_capabilities
Virtualization:                     AMD-V
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          14 MiB (224 instances)
L1i cache:                          14 MiB (224 instances)
L2 cache:                           112 MiB (224 instances)
L3 cache:                           3.5 GiB (224 instances)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-223
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==1.23.1
[pip3] onnx==1.14.0
[pip3] pytorch-quantization==2.1.2
[pip3] torch==2.0.1
[pip3] torch-tensorrt==2.0.0.dev0
[pip3] torchdata==0.7.0a0
[pip3] torchtext==0.16.0a0
[pip3] torchvision==0.16.0
[pip3] triton==2.0.0
[conda] Could not collect

Environment 2

Collecting environment information...
PyTorch version: 2.1.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.27.4
Libc version: glibc-2.35

Python version: 3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-86-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.2.128
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB

Nvidia driver version: 535.104.12
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      48 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             224
On-line CPU(s) list:                0-223
Vendor ID:                          AuthenticAMD
Model name:                         AMD EPYC 7713 64-Core Processor
CPU family:                         25
Model:                              1
Thread(s) per core:                 1
Core(s) per socket:                 224
Socket(s):                          1
Stepping:                           1
BogoMIPS:                           3999.99
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw perfctr_core ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves clzero xsaveerptr wbnoinvd arat npt nrip_save umip vaes vpclmulqdq rdpid fsrm arch_capabilities
Virtualization:                     AMD-V
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          14 MiB (224 instances)
L1i cache:                          14 MiB (224 instances)
L2 cache:                           112 MiB (224 instances)
L3 cache:                           3.5 GiB (224 instances)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-223
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==1.23.1
[pip3] onnx==1.14.0
[pip3] pytorch-quantization==2.1.2
[pip3] torch==2.1.0
[pip3] torch-tensorrt==2.0.0.dev0
[pip3] torchdata==0.7.0a0
[pip3] torchtext==0.16.0a0
[pip3] torchvision==0.16.0
[pip3] triton==2.1.0+e621604
[conda] Could not collect

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519 @wanchaol @fduwjj @wz337 @kiukchung @d4l3k @LucasLLC @tianyu-l @gchanan @kadeng

KeremTurgutlu (Author) commented:

"""
A minimal reproduction of torch.compile, FSDP and torch 2.1 error.

RuntimeError: attempting to assign a gradient with dtype 'c10::BFloat16' to a tensor with dtype 'float'.
Please ensure that the gradient and the tensor have the same dtype
"""

import logging
import math
import os
import time
from pathlib import Path
import functools
from functools import partial
from torch.optim.lr_scheduler import LambdaLR

import datasets
import torch
import transformers
import tokenizers
from tqdm.auto import tqdm
from transformers import AutoTokenizer

import accelerate
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from accelerate.utils.dataclasses import ProjectConfiguration
from accelerate import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import (MixedPrecision, 
                                                                FullStateDictConfig, 
                                                                FullOptimStateDictConfig, 
                                                                ShardingStrategy,
                                                                BackwardPrefetch,
                                                                StateDictType
                                                                )
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


# from pretraining.losses import get_lm_loss_func
# from pretraining.scheduler import get_cosine_one_cycle_scheduler
# from pretraining.utils import (get_param_counts, 
#                                enable_gradient_checkpointing,
#                                get_optim_param_groups)
from pretraining.arguments.pretraining_arguments import parse_args

from transformers import LlamaConfig, LlamaForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# For speed testing purposes only.
from pretraining.benchmarking.dummy_data_utils import create_dummy_dataloaders

logger = get_logger(__name__, log_level="INFO")


def main():
    args = parse_args()

    mixed_precision_policy = MixedPrecision(param_dtype=torch.bfloat16, 
                                            reduce_dtype=torch.bfloat16, 
                                            buffer_dtype=torch.float32)
    
    # Add embedding and lm head if needed.
    llama_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            LlamaDecoderLayer,
        },
    )
    
    fsdp_plugin = FullyShardedDataParallelPlugin(
                        sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
                        auto_wrap_policy=llama_auto_wrap_policy,
                        mixed_precision_policy=mixed_precision_policy,
                        state_dict_config=FullStateDictConfig(offload_to_cpu=False),
                        optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=False),
                        backward_prefetch = BackwardPrefetch.BACKWARD_PRE,
                        state_dict_type=StateDictType.FULL_STATE_DICT,
                        forward_prefetch=False,
                        use_orig_params=True,
                        cpu_offload=False,
                    )    

    accelerator = Accelerator(fsdp_plugin=fsdp_plugin,
                    log_with=args.report_to, 
                    project_config=ProjectConfiguration(project_dir=None,
                                                        logging_dir=args.output_dir,
                                                        automatic_checkpoint_naming=False),
                    gradient_accumulation_steps=args.gradient_accumulation_steps) 
                   
    
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    logger.info(f"FSDP Mixed Precision Policy: {accelerator.state.fsdp_plugin.mixed_precision_policy}")
    logger.info(f"Native AMP is enabled: {accelerator.native_amp}")
    
    # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
    # in PyTorch 1.12 and later.
    if not torch.backends.cuda.matmul.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        logger.info("Setting torch.backends.cuda.matmul.allow_tf32 = True")
    else:
        logger.info("Already set: torch.backends.cuda.matmul.allow_tf32 = True")
    # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
    if not torch.backends.cudnn.allow_tf32:
        torch.backends.cudnn.allow_tf32 = True
        logger.info("Setting torch.backends.cudnn.allow_tf32 = True")
    else:
        logger.info("Already set: torch.backends.cudnn.allow_tf32 = True")


    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Create output and experiment directory if needed
    task_prefix = 'clm' if not args.prefix_lm else 'plm'
    experiment_name = (f"{task_prefix}_{args.model_size}_{args.optimizer}" 
                        if args.experiment_name is None else args.experiment_name)
    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)
        os.makedirs(os.path.join(args.output_dir, experiment_name), exist_ok=True)
    accelerator.wait_for_everyone()

    # Initialize config.    
    # elif args.model_size == "1B":
    model_size_config = dict(hidden_size=2048,
                            num_hidden_layers=24,
                            num_attention_heads=16,
                            num_key_value_heads=16,
                            intermediate_size=4096)

    # download model weights and config files.
    config = LlamaConfig()
    config.update(model_size_config)
    logger.info(f"Model config: {config.to_json_string()}")
    
    # Load tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(args.custom_tokenizer_path, use_fast=True)
    
    # Update config with vocab size.
    config.vocab_size = len(tokenizer.vocab)
    
    # Initialize from pretrained LLaMa model.  
    model = LlamaForCausalLM(config)

    # In case tokenizer has extra tokens.
    prev_shape = model.model.embed_tokens.weight.size()
    model.resize_token_embeddings(len(tokenizer))
    logger.info(f"Resized word embeddings from: {prev_shape} to: {model.model.embed_tokens.weight.size()}")

    assert model.model.embed_tokens.weight.size(1) % 8 == 0, f"embed_tokens must be divisible by 8."
    assert model.model.embed_tokens.weight.size() == model.lm_head.weight.size(), \
          (f"embed_tokens {model.model.embed_tokens.weight.size()} "
           f"and lm_head {model.lm_head.weight.size()} shapes must be the same.")

    # Create dataloaders.
    (train_dl,valid_dl,train_ds,valid_ds) = create_dummy_dataloaders(args.dataset_name,
                                                                        tokenizer,
                                                                        args.block_size,
                                                                        args.per_device_train_batch_size,
                                                                        packed_inputs=False,
                                                                        prefix_lm=False)

    dataset_length = len(train_ds)
    
    # Total batch size.
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    logger.info(f"Total batch size: {total_batch_size}")

    # Initialize loss func. Using torch.nn.CrossEntropyLoss().
    # loss_func = get_lm_loss_func(apex=False, flash_attn=False)
    class LMLossTorch:
        def __init__(self):
            self.loss_fct = torch.nn.functional.cross_entropy
            logger.info("Using PyTorch cross_entropy")

        def compute(self, logits, labels, loss_mask=None, z_loss:float=None, ignore_index=-100):
            # Ignore prediction for the last token.
            shift_logits = logits[...,:-1,:].contiguous()
            shift_labels = labels[...,1:].contiguous().long()  
            if loss_mask is not None:
                loss_mask = loss_mask[...,:-1].contiguous().bool()
                shift_labels.masked_fill_(~loss_mask, ignore_index)
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = shift_labels.view(-1)
            
            # args: input, target, weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean', label_smoothing=0.0
            loss = self.loss_fct(shift_logits, shift_labels, ignore_index=ignore_index)
            if z_loss is not None:
                log_z = torch.logsumexp(shift_logits[shift_labels!=ignore_index], dim=-1)**2
                z_loss_val = z_loss*log_z.mean()
                loss += z_loss_val
                return loss, z_loss_val
            return loss, None
    
    loss_func = LMLossTorch()

    # Calculate total number training steps.
    num_update_steps_per_epoch = math.ceil(len(train_dl) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    else:
        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    logger.info(f"num_update_steps_per_epoch (before prepare) : {num_update_steps_per_epoch}")
    logger.info(f"max_train_steps (before prepare) : {args.max_train_steps}")
    
    # Creates Dummy Scheduler if `scheduler` was specified in the config file else creates `args.lr_scheduler_type` Scheduler
    total_scheduler_steps = int(args.max_train_steps*accelerator.num_processes)
    if (args.num_warmup_steps is not None) and (args.num_warmup_fraction is not None):
        raise ValueError("Only one of num_warmup_steps or num_warmup_fraction can be specified.")
    if args.num_warmup_steps is not None:
        scheduler_warmup_steps = args.num_warmup_steps
    elif args.num_warmup_fraction is not None:
        scheduler_warmup_steps = int(args.num_warmup_fraction * total_scheduler_steps)
        
    # prepare model first in FSDP.
    model = accelerator.prepare(model)

    def get_optim_param_groups(model,  weight_decay, no_decay):
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        
        decay_names = [n for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)]
        no_decay_names = [n for n, p in model.named_parameters() if any(nd in n for nd in no_decay)]
        return optimizer_grouped_parameters, no_decay_names, decay_names
    optimizer_grouped_parameters, no_decay_names, decay_names = get_optim_param_groups(model, args.weight_decay, 
                                                                                        no_decay=["embed", "bias", "norm.weight"])
    
    newline = '\n'
    logger.info(f"No decay: {newline.join(no_decay_names)}")
    logger.info(f"Decay: {newline.join(decay_names)}")                                                                

    # NOTE: Fused AdamW didn't work with FSDP.
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, 
                                    lr=args.learning_rate,
                                    betas=(0.9,0.95),
                                    eps=1e-5)          
                             
    def _get_cosine_one_cycle_lr_lambda(
        current_step: int, *, num_warmup_steps: int, num_training_steps: int, min_lr_fraction = 0.1,
    ):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))  
        scale_term = (1 - min_lr_fraction)
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return (math.cos(math.pi * progress)+1) * 0.5 * scale_term + min_lr_fraction

    def get_cosine_one_cycle_scheduler(optimizer, num_warmup_steps, num_training_steps, min_lr_fraction=0.1):
        lr_lambda = partial(
            _get_cosine_one_cycle_lr_lambda,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            min_lr_fraction=min_lr_fraction
        )
        return LambdaLR(optimizer, lr_lambda, last_epoch=-1)                             
    lr_scheduler = get_cosine_one_cycle_scheduler(optimizer, 
                                                    num_warmup_steps=scheduler_warmup_steps, 
                                                    num_training_steps=total_scheduler_steps,
                                                    min_lr_fraction=args.min_lr_fraction)  

    # Prepare the remaining objects with our `accelerator`.
    optimizer, train_dl, valid_dl, lr_scheduler = accelerator.prepare(
        optimizer, train_dl, valid_dl, lr_scheduler
    )

    if isinstance(lr_scheduler, accelerate.scheduler.AcceleratedScheduler):
        logger.info(f"Using AcceleratedScheduler")

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    # Based on dataloader preparation strategy, for IterableDataset batches will be dispatched from main process.
    if args.total_dataset_tokens is None:
        num_update_steps_per_epoch = math.ceil(len(train_dl) / args.gradient_accumulation_steps)
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        logger.info(f"num_update_steps_per_epoch (after prepare) : {num_update_steps_per_epoch}")
        logger.info(f"max_train_steps (after prepare) : {args.max_train_steps}")

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    experiment_config = vars(args)
    # TensorBoard cannot log Enums, need the raw value
    experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
    tracker_project_name = Path(args.output_dir).parent.name
    accelerator.init_trackers(tracker_project_name, experiment_config)
    accelerator.log({"model_config": config.to_json_string()})


    # Train!
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {dataset_length}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    completed_steps = 0
    starting_epoch = 0

    for epoch in range(starting_epoch, args.num_train_epochs):
        model.train()

        # Running loss trackers.
        total_loss = 0
        step_total_loss = 0
        step_total_zloss = 0
        ema_loss = None
        for step, batch in enumerate(train_dl):

            # forward pass: `input_ids`, `decoder_causal_attention`, `decoder_segment_ids`
            batch['input_ids'] = batch.pop(args.input_ids_key)
            labels = batch["input_ids"].clone()
            loss_mask = batch.pop('decoder_loss_mask', None)
            _ = batch.pop('decoder_causal_attention', None)
            # TODO: Check if this is faster than manual block diagonal attention bias.
            if args.multipacked_inputs:
                batch['input_ids'] = batch['input_ids'].view(-1).unsqueeze(0)
                labels = labels.view(-1).unsqueeze(0)
                loss_mask = loss_mask.view(-1).unsqueeze(0)
                # TODO: Disable for now. May be affecting speed.
                # batch['position_ids'] = batch['position_ids'].view(-1).unsqueeze(0)
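                # Offset each row's segment ids by the running total of segments in the
                # previous rows, so the ids stay unique after the rows are flattened into
                # a single packed sequence below.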
                offset = torch.cat([torch.tensor([0], device=batch['decoder_segment_ids'].device), 
                                    batch['decoder_segment_ids'][:-1,-1].cumsum(0)]).view(-1,1)
                batch['attention_mask'] = (batch['decoder_segment_ids'] + offset).view(-1).unsqueeze(0)
                                
            outputs = model(**batch)
            # compute loss: `logits`, `labels` `decoder_loss_mask`
            loss, _ = loss_func.compute(outputs.logits, labels, loss_mask, 
                                                 z_loss=args.z_loss, ignore_index=args.ignore_index)

            # We keep track of the loss at each epoch
            accelerator.backward(loss)
            
            if (step + 1) % args.gradient_accumulation_steps == 0:
                # No effect with deepspeed, can keep. Use deepspeed config to enable it.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), args.gradient_clipping)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                completed_steps += 1

                step_total_loss = step_total_loss / args.gradient_accumulation_steps
                step_total_zloss = step_total_zloss / args.gradient_accumulation_steps
                               
    logger.info("Training complete!")

if __name__ == "__main__":
    main()
"""
Create test dataloader
"""

"""
Prepare dummy tokenized dataset with 1% of data.
"""
from itertools import chain
import functools
from datasets import load_dataset
from transformers import default_data_collator
from torch.utils.data import DataLoader
import numpy as np

def tokenize_function(examples,tokenizer,text_column_name):
    # no attention mask needed.
    return tokenizer(examples[text_column_name], return_attention_mask=False, return_token_type_ids=False)

def group_texts(examples, block_size, packed_inputs, prefix_lm):
    # Concatenate all texts.
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder; we could add padding instead if the model supported it.
    # You can customize this part to your needs.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    if packed_inputs:
        result["decoder_segment_ids"] = [[1]*len(l) for l in result["input_ids"]]
    if prefix_lm:
        result["decoder_causal_attention"] = [[0]*len(l) for l in result["input_ids"]]
    return result

def create_dummy_dataloaders(dataset_name, tokenizer, sequence_length, batch_size, packed_inputs=True, prefix_lm=False):

    raw_datasets = load_dataset(dataset_name)
    raw_datasets["validation"] = load_dataset(
        dataset_name,
        split=f"train[:1%]",
    )
    # Use a dummy dataset for benchmarking, set to 1% of the data.
    raw_datasets["train"] = load_dataset(
        dataset_name,
        split=f"train[:1%]",
    )

    column_names = raw_datasets["train"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    tokenized_datasets = raw_datasets.map(
        functools.partial(tokenize_function, tokenizer=tokenizer, text_column_name=text_column_name),
        batched=True,
        num_proc=None,
        remove_columns=column_names,
        load_from_cache_file=True,
        desc="Running tokenizer on dataset",
    )

    lm_datasets = tokenized_datasets.map(
        functools.partial(group_texts, block_size=sequence_length, packed_inputs=packed_inputs, prefix_lm=prefix_lm),
        batched=True,
        num_proc=None,
        load_from_cache_file=True,
        desc=f"Grouping texts in chunks of {sequence_length}",
    )
    
    lm_datasets = lm_datasets.rename_column("input_ids", "targets")

    train_dataloader = DataLoader(
        lm_datasets["train"], shuffle=False, collate_fn=default_data_collator, batch_size=batch_size
    )
    eval_dataloader = DataLoader(
        lm_datasets["validation"], collate_fn=default_data_collator, batch_size=batch_size
    )
    return train_dataloader, eval_dataloader, lm_datasets["train"], lm_datasets["validation"]
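
# A hedged usage sketch: the dataset name and tokenizer below are placeholders chosen
# only because they load without extra configuration, not what the original repro used.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    train_dl, valid_dl, train_ds, valid_ds = create_dummy_dataloaders(
        "imdb", tokenizer, sequence_length=1024, batch_size=4
    )
    print(next(iter(train_dl)).keys())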

@KeremTurgutlu KeremTurgutlu changed the title Torch 2.1 compile + FSDP (mixed precision) + LlamaForCausalLM RuntimeError: attempting to assign a gradient with dtype 'c10::BFloat16' to a tensor with dtype 'float'. Torch 2.1 compile + FSDP (mixed precision) + LlamaForCausalLM: RuntimeError: attempting to assign a gradient with dtype 'c10::BFloat16' to a tensor with dtype 'float'. Oct 15, 2023
@ezyang ezyang added high priority oncall: distributed Add this issue/PR to distributed oncall triage queue module: fsdp labels Oct 16, 2023
@desertfire desertfire added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Oct 16, 2023
@desertfire
Copy link
Contributor

@wconstab, assigned to you temporarily; feel free to reassign to the right owner.

@jon-chuang
Copy link
Collaborator

jon-chuang commented Oct 26, 2023

A previous issue, possibly due to FSDP + mixed precision + dynamo + autograd (assigning a gradient of incorrect dtype): #110797

@jon-chuang
Copy link
Collaborator

jon-chuang commented Oct 26, 2023

Actually, couldn't #111794 be the root cause of it all? That issue is about not handling autocast context-manager setup/unwind appropriately when using DDP. Not sure if this bleeds into FSDP.


Could you confirm if the split_module pass is used for FSDP torch.compile @wconstab?

EDIT: turns out the answer is no, but does FSDP introduce its own graph breaks?

@jon-chuang
Copy link
Collaborator

I also tried enabling and disabling amp autocast, it doesn't matter and the same error happens.

I still see autocast in your stacktrace. Though you claim to have disabled it, I think that some of the model code has been decorated with autocast.

@wconstab
Copy link
Contributor

wconstab commented Oct 26, 2023

Could you confirm if the split_module pass is used for FSDP torch.compile @wconstab?

The fx graph splitter pass is not used at all for FSDP. FSDP's graph breaks happen more implicitly: Dynamo sees fsdp's python code and is configured to graph-break on it. In the ddp case, there is not really any DDP code running during forward, and the reason to insert the graph-breaks only becomes apparent when you consider the comm operations that happen during backward. So we had to take another approach to add the graph-breaks, using the fx pass.
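
Dynamo's graph breaks (whether from FSDP's hooks or anything else it can't trace) can be inspected; here is a minimal sketch, not taken from this repro: the toy module, the deliberate print, and the random input are placeholders, and it assumes the torch._dynamo.explain helper as it exists in recent 2.x releases.

import torch
import torch._dynamo

class Toy(torch.nn.Module):
    def forward(self, x):
        x = torch.relu(x)
        print("this print forces a graph break")  # untraceable side effect
        return x.sum()

# explain() runs the callable under Dynamo and reports graph count, break count and reasons.
print(torch._dynamo.explain(Toy())(torch.randn(8)))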

@wconstab wconstab reopened this Oct 26, 2023
@KeremTurgutlu
Copy link
Author

I also tried enabling and disabling amp autocast, it doesn't matter and the same error happens.

I still see autocast in your stacktrace. Though you claim to have disabled it, I think that some of the model code has been decorated with autocast.

Sorry, I only included one of the stack traces; in this case it happens to be the autocast run, but the same error happened without it as well.

@jon-chuang
Copy link
Collaborator

@KeremTurgutlu, may I ask what your settings are for the supposedly non-autocast run, for these?

mixed_precision_policy = MixedPrecision(param_dtype=torch.bfloat16, 
                                        reduce_dtype=torch.bfloat16, 
                                        buffer_dtype=torch.float32)

logger.info(f"FSDP Mixed Precision Policy: {accelerator.state.fsdp_plugin.mixed_precision_policy}")
logger.info(f"Native AMP is enabled: {accelerator.native_amp}")

Could you share the logs so I can confirm?

@voznesenskym
Copy link
Collaborator

@wconstab are you still working on this?

@wconstab wconstab removed their assignment Nov 14, 2023
@wconstab
Copy link
Contributor

no, i never got a chance to look into this. unassigning for now

@psinger
Copy link

psinger commented Dec 14, 2023

Seeing the same issue when using the combination of FSDP + compile + MP.

Using amp autocast I don't see these issues, but whenever I add the MixedPrecision setting, this error occurs.
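
To make that distinction concrete, here is a minimal sketch of the two configurations being compared. It assumes a single-node torchrun launch with NCCL and a CUDA device; the tiny Linear model and random batch are placeholders, not the model from this issue.

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())
batch = torch.randn(4, 16, device="cuda")

# (a) autocast only: FSDP keeps fp32 flat params; the bf16 cast happens around compute.
model_a = FSDP(torch.nn.Linear(16, 16).cuda())
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    model_a(batch).sum().backward()

# (b) FSDP MixedPrecision: FSDP itself holds bf16 unsharded params, so backward
#     produces bf16 gradients for them.
policy = MixedPrecision(param_dtype=torch.bfloat16,
                        reduce_dtype=torch.bfloat16,
                        buffer_dtype=torch.float32)
model_b = FSDP(torch.nn.Linear(16, 16).cuda(), mixed_precision=policy)
model_b(batch).sum().backward()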

@gouchangjiang
Copy link

gouchangjiang commented Dec 16, 2023

mixed_precision_policy = MixedPrecision(param_dtype=torch.bfloat16, 
                                        reduce_dtype=torch.bfloat16, 
                                        buffer_dtype=torch.float32)

logger.info(f"FSDP Mixed Precision Policy: {accelerator.state.fsdp_plugin.mixed_precision_policy}")
logger.info(f"Native AMP is enabled: {accelerator.native_amp}")

Could you share the logs so I can confirm?

Hi @jon-chuang, I can confirm that even with the settings 'param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32' you gave, the same error occurs. Please see the logs below:
File "/opt/conda/envs/python3.7/lib/python3.8/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 839, in forward
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/python3.7/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/envs/python3.7/lib/python3.8/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 839, in forward
return forward_call(*args, **kwargs)
File "/opt/conda/envs/python3.7/lib/python3.8/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 825, in forward
unshard_fn(state, handle)
File "/opt/conda/envs/python3.7/lib/python3.8/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 464, in _pre_forward_unshard
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/python3.7/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
_unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
File "/opt/conda/envs/python3.7/lib/python3.8/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 336, in _unshard
args, kwargs = _pre_forward(
File "/opt/conda/envs/python3.7/lib/python3.8/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 429, in _pre_forward
return forward_call(*args, **kwargs)
File "/opt/conda/envs/python3.7/lib/python3.8/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 825, in forward
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/python3.7/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/envs/python3.7/lib/python3.8/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 825, in forward
ran_pre_unshard = handle.pre_unshard()
File "/opt/conda/envs/python3.7/lib/python3.8/site-packages/torch/distributed/fsdp/flat_param.py", line 1194, in pre_unshard
output = self._fsdp_wrapped_module(*args, **kwargs)
File "/opt/conda/envs/python3.7/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
output = self._fsdp_wrapped_module(*args, **kwargs)
File "/opt/conda/envs/python3.7/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
unshard_fn(state, handle)
File "/opt/conda/envs/python3.7/lib/python3.8/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 464, in _pre_forward_unshard
args, kwargs = _pre_forward(
File "/opt/conda/envs/python3.7/lib/python3.8/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 429, in _pre_forward
args, kwargs = _pre_forward(
File "/opt/conda/envs/python3.7/lib/python3.8/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 429, in _pre_forward
_unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
File "/opt/conda/envs/python3.7/lib/python3.8/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 336, in _unshard
unshard_fn(state, handle)
File "/opt/conda/envs/python3.7/lib/python3.8/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 464, in _pre_forward_unshard
ret = self._writeback_orig_params()
File "/opt/conda/envs/python3.7/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
unshard_fn(state, handle)
File "/opt/conda/envs/python3.7/lib/python3.8/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 464, in _pre_forward_unshard
ran_pre_unshard = handle.pre_unshard()
File "/opt/conda/envs/python3.7/lib/python3.8/site-packages/torch/distributed/fsdp/flat_param.py", line 1194, in pre_unshard
return func(*args, **kwargs)
File "/opt/conda/envs/python3.7/lib/python3.8/site-packages/torch/distributed/fsdp/flat_param.py", line 2202, in _writeback_orig_params
_unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
File "/opt/conda/envs/python3.7/lib/python3.8/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 336, in _unshard
return forward_call(*args, **kwargs)
_unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
File "/home/admin/hippo/worker/slave/aop_418921_fsdp_32gpu_debug_20231216224420_3434820_tfjob.ps_1_50_10/tf_worker/tmp_1007/taoscale/models/transformer.py", line 273, in forward

@gouchangjiang
Copy link

Additional info: this issue disappears on PyTorch 2.0.0 and PyTorch 2.1.2.

@tangjiasheng
Copy link

Hit a similar error while using FSDP and compile together.
attempting to assign a gradient with dtype 'float' to a tensor with dtype 'c10::Half'. Please ensure that the gradient and the tensor have the same dtype.

If I set the mixed precision dtype to fp32, the error turns into a tensor size mismatch:
torch._dynamo.exc.InternalTorchDynamoError: attempting to assign a gradient of size '[1828962]' to a tensor of size '[3657924]'. Please ensure that the gradient and the tensor are the same size
since I run with 2 cards, and 1828962 * 2 = 3657924.

It seems the gradient of FSDP's flat parameter (FlatParameter) is not handled correctly by compile.

My torch version is 2.2 and the way I use FSDP is like this:

FSDP(unet, mixed_precision=fpSixteen, auto_wrap_policy=my_size_based_auto_wrap_policy,
     device_id=torch.cuda.current_device(),
     sharding_strategy=ShardingStrategy.SHARD_GRAD_OP, use_orig_params=True)
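
(fpSixteen and my_size_based_auto_wrap_policy are not shown in the comment; the following is a hedged guess at how they might be defined, using the stock size-based wrap policy, purely for context.)

import functools
import torch
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

fpSixteen = MixedPrecision(param_dtype=torch.float16,
                           reduce_dtype=torch.float16,
                           buffer_dtype=torch.float16)
my_size_based_auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy, min_num_params=1_000_000  # threshold is a guess
)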

@psinger
Copy link

psinger commented Feb 15, 2024

@tangjiasheng did you manage to find a solution for the size mismatch? Running into the same error, and this is the only post referencing it I could find.

@tangjiasheng
Copy link

@tangjiasheng did you manage to find a solution for the size mismatch? Running into the same error, and this is the only post referencing it I could find.

Sorry, the answer is NO...

@yf225 yf225 self-assigned this Feb 27, 2024
@Skylion007
Copy link
Collaborator

We are also hitting this issue with SHARD_GRAD_OP. Specifically, we hit it on the first eval step after a training step has run, while training Stable Diffusion 2. If we never do evaluation / model saving, training seems to run fine. Hopefully this helps narrow it down a bit @yf225 .

I can confirm that this issue is present on 2.2 as well as on nightly as of last week.

@psinger
Copy link

psinger commented Mar 6, 2024

My current hypothesis is that these errors happen due to OOM, but no proper OOM error is raised; only the errors discussed in this thread appear.

@zejun-chen
Copy link
Contributor

Also hit the issue with both bf16 and fp16 training.

[rank7]:   File "/home/bduser/zejun/pytorch/torch/_dynamo/symbolic_convert.py", line 754, in step
[rank7]:     getattr(self, inst.opname)(inst)
[rank7]:   File "/home/bduser/zejun/pytorch/torch/_dynamo/symbolic_convert.py", line 1292, in LOAD_ATTR
[rank7]:     result = BuiltinVariable(getattr).call_function(
[rank7]:   File "/home/bduser/zejun/pytorch/torch/_dynamo/variables/builtin.py", line 697, in call_function
[rank7]:     result = handler(tx, *args, **kwargs)
[rank7]:   File "/home/bduser/zejun/pytorch/torch/_dynamo/variables/builtin.py", line 1296, in call_getattr
[rank7]:     grapharg.example.grad = torch.zeros(
[rank7]: torch._dynamo.exc.InternalTorchDynamoError: attempting to assign a gradient with dtype 'float' to a tensor with dtype 'c10::Half'. Please ensure that the gradient and the tensor
 have the same dtype

@tangjiasheng
Copy link

Also fails with ShardingStrategy.FULL_SHARD.

@Skylion007
Copy link
Collaborator

I encounter this issue when I run an evaluation step after a training step. If I do evaluation before any training steps, torch.compile works fine and training runs until the first evaluation step after a compiled backward pass.

Seems like this has to do with no_grad not being respected in the eval pass, or with different mixed precision / autocast behavior in that pass.
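
If it helps to check that, here is a minimal sketch of the eval pattern being described (reusing the model / valid_dl names from the script at the top of this issue; this is just the pattern, not a fix):

model.eval()
with torch.no_grad():
    for batch in valid_dl:
        outputs = model(**batch)  # compiled FSDP forward only, no backward
model.train()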

@0-hero
Copy link

0-hero commented Apr 6, 2024

My current hypothesis is that these errors happen due to OOM, but no proper OOM error is present but the ones discussed in this thread.

This is the case for me. Reducing the batch size worked.

@jasonkrone
Copy link

Also hitting this issue with torch version: 2.3.0a0+6ddf5cf85e.nv24.04

As Skylion007 mentioned, this issue occurs for me if I do train before eval, but it goes away if I do eval before train.

@OrenLeung
Copy link
Contributor

OrenLeung commented Jun 24, 2024

Any updates?

Running into RuntimeError: attempting to assign a gradient with dtype 'c10::BFloat16' to a tensor with dtype 'float' with FSDP + bf16 amp autocast + torch.compile.

@anijain2305
Copy link
Contributor

#134614 might be able to fix this issue.

@yf225 yf225 assigned anijain2305 and unassigned yf225 Aug 28, 2024
@yf225
Copy link
Contributor

yf225 commented Aug 28, 2024

#134614 should fix this issue. Please reopen the issue if it's not the case.

@yf225 yf225 closed this as completed Aug 28, 2024