# FSDP to YaFSDP migration guide

Sharding a model with FSDP might look similar to the example below:

```python
from functools import partial

import torch
from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from transformers import LlamaForCausalLM

model: LlamaForCausalLM = ...

FSDP(
    model,
    sharding_strategy=sharding_strategy,
    # Wrap the embedding, every transformer block, and the output head
    # as separate FSDP units.
    auto_wrap_policy=partial(
        lambda_auto_wrap_policy,
        lambda_fn=lambda m: (
            m is model.model.embed_tokens
            or m in model.model.layers
            or m is model.lm_head
        ),
    ),
    mixed_precision=MixedPrecision(param_dtype=param_dtype, reduce_dtype=torch.float32),
    sync_module_states=sync_module_states,
    param_init_fn=param_init_fn,
    device_id=device,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    forward_prefetch=True,
    use_orig_params=True,
)
```

An equivalent YaFSDP sharding looks like this:

```python
from torch.distributed.fsdp import ShardingStrategy
from transformers import LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from ya_fsdp import YaFSDP

model: LlamaForCausalLM = ...

YaFSDP(
    model,
    # Map the FSDP sharding strategy onto the equivalent ZeRO stage.
    zero_stage={
        ShardingStrategy.FULL_SHARD: 3,
        ShardingStrategy.SHARD_GRAD_OP: 2,
    }[sharding_strategy],
    # Explicit (module, state-dict name) pairs replace auto_wrap_policy.
    modules_to_wrap_with_names=[
        (model.model.embed_tokens, "model.embed_tokens"),
        *((m, f"model.layers.{i}") for i, m in enumerate(model.model.layers)),
        (model.lm_head, "lm_head"),
    ],
    output_layer_module_with_name=(model.model.norm, "model.norm"),
    layer_norm_module_cls=LlamaRMSNorm,
    param_dtype=param_dtype,
    sync_module_states=sync_module_states,
    param_init_fn=param_init_fn,
    device_id=device,
    gradient_accumulation_steps=gradient_accumulation_steps,
)
```

A major interface difference is that `auto_wrap_policy` is replaced by three
arguments:

- `modules_to_wrap_with_names` — an explicit list of the modules to shard,
  each paired with the name it should carry in the state dict (see the sketch
  after this list)
- `output_layer_module_with_name` — the first module after all transformer
  blocks, which typically contains only layer-norm parameters, together with
  its name
- `layer_norm_module_cls` — the class of the layer-norm modules
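
If listing the wrapped modules by hand feels error-prone, the (module, name)
pairs can also be derived from the model's own `named_modules()`, which
guarantees the names match the state-dict keys. A minimal sketch, assuming the
`LlamaForCausalLM` layout from the examples above (`llama_modules_to_wrap` is
a hypothetical helper, not part of YaFSDP):

```python
from torch import nn

def llama_modules_to_wrap(model: nn.Module) -> list[tuple[nn.Module, str]]:
    # The modules we want as separate YaFSDP units: the embedding, every
    # transformer block, and the output head (the same set the FSDP
    # lambda_fn above selects).
    wanted = {id(model.model.embed_tokens), id(model.lm_head)}
    wanted |= {id(m) for m in model.model.layers}
    # named_modules() yields dotted, state-dict-style names such as
    # "model.embed_tokens", "model.layers.0", and "lm_head".
    return [(m, name) for name, m in model.named_modules() if id(m) in wanted]
```

With such a helper, the call above becomes
`modules_to_wrap_with_names=llama_modules_to_wrap(model)`.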

YaFSDP also requires the number of gradient accumulation steps to be provided
explicitly via `gradient_accumulation_steps`, whereas with FSDP accumulation
is typically handled through its `no_sync()` context manager.
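
For context, here is a minimal sketch of the surrounding accumulation loop. It
is the standard pattern rather than anything YaFSDP-specific, and
`wrapped_model`, `optimizer`, and `data_loader` are hypothetical stand-ins:

```python
for step, batch in enumerate(data_loader):
    # Scale each micro-batch loss so the accumulated gradient matches a
    # single large-batch step.
    loss = wrapped_model(**batch).loss / gradient_accumulation_steps
    loss.backward()
    # Step the optimizer only after the declared number of micro-batches,
    # matching the gradient_accumulation_steps passed to YaFSDP above.
    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```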