Add a FSDP to YaFSDP migration guide (#9)
antony-frolov committed Jun 25, 2024
1 parent 71f3ed3 commit cb11293
Showing 2 changed files with 67 additions and 1 deletion.
docs/migration_guide.md (66 additions, 0 deletions)
@@ -0,0 +1,66 @@
# FSDP to YaFSDP migration guide

Sharding a model with FSDP might look similar to the example below:

```python
from functools import partial

import torch
from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from transformers import LlamaForCausalLM

model: LlamaForCausalLM = ...

FSDP(
    model,
    sharding_strategy=sharding_strategy,
    auto_wrap_policy=partial(
        lambda_auto_wrap_policy,
        lambda_fn=lambda m: (
            m is model.model.embed_tokens
            or m in model.model.layers
            or m is model.lm_head
        ),
    ),
    mixed_precision=MixedPrecision(param_dtype=param_dtype, reduce_dtype=torch.float32),
    sync_module_states=sync_module_states,
    param_init_fn=param_init_fn,
    device_id=device,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    forward_prefetch=True,
    use_orig_params=True,
)
```

An equivalent YaFSDP sharding looks like this:

```python
from torch.distributed.fsdp import ShardingStrategy
from transformers import LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaRMSNorm

from ya_fsdp import YaFSDP

model: LlamaForCausalLM = ...

YaFSDP(
    model,
    zero_stage={
        ShardingStrategy.FULL_SHARD: 3,
        ShardingStrategy.SHARD_GRAD_OP: 2,
    }[sharding_strategy],
    modules_to_wrap_with_names=[
        (model.model.embed_tokens, "model.embed_tokens"),
        *((m, f"model.layers.{i}") for i, m in enumerate(model.model.layers)),
        (model.lm_head, "lm_head"),
    ],
    output_layer_module_with_name=(model.model.norm, "model.norm"),
    layer_norm_module_cls=LlamaRMSNorm,
    param_dtype=param_dtype,
    sync_module_states=sync_module_states,
    param_init_fn=param_init_fn,
    device_id=device,
    gradient_accumulation_steps=gradient_accumulation_steps,
)
```

The major interface difference is that `auto_wrap_policy` is replaced with three
arguments:

- `modules_to_wrap_with_names` — an explicit list of the modules to shard, each
  paired with the name under which it appears in the state dict (see the sketch
  after this list)
- `output_layer_module_with_name` — the first module after all transformer
  blocks, which is expected to contain only layer norm parameters (as it
  typically does), paired with its name
- `layer_norm_module_cls` — the class of the layer norm modules
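
If your codebase already defines an FSDP-style wrap predicate, the explicit
list can be built from `nn.Module.named_modules()`. The helper below is a
hypothetical sketch for a Hugging Face `LlamaForCausalLM`, not part of the
YaFSDP API:

```python
from torch import nn


def modules_with_names(model: nn.Module, should_wrap) -> list[tuple[nn.Module, str]]:
    # Collect (module, qualified name) pairs for every submodule the predicate
    # selects; the qualified names double as state dict prefixes.
    return [(module, name) for name, module in model.named_modules() if should_wrap(module)]


# Reusing the predicate from the FSDP example above:
modules_to_wrap_with_names = modules_with_names(
    model,
    lambda m: m is model.model.embed_tokens or m in model.model.layers or m is model.lm_head,
)
```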

YaFSDP also requires the number of gradient accumulation steps to be provided
explicitly via the `gradient_accumulation_steps` argument.
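
As a point of reference, a standard accumulation loop is sketched below. It is
plain PyTorch rather than YaFSDP-specific API, and `optimizer`, `data_loader`,
and `gradient_accumulation_steps` (which must match the constructor argument)
are assumed to be defined:

```python
for step, batch in enumerate(data_loader):
    # Scale the loss so gradients average over the accumulation window.
    loss = model(**batch).loss / gradient_accumulation_steps
    loss.backward()
    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```
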
ya_fsdp/ya_fsdp.py (1 addition, 1 deletion)
```diff
@@ -61,7 +61,7 @@ class YaFSDP(nn.Module):
         bit32_acc_for_bit16_reduce_scatter (bool, optional):
             If True, uses a custom kernel for float32 accuracy for bfloat16 reduce scatter. Defaults to False.
         hpz_first_layers_num (int, optional):
-            Number for layers to apply HPZ to.
+            Number of layers to apply HPZ to.
             Defaults to 0.
         output_layer_module_with_name (tuple[nn.Module, str] | None, optional):
             Instance of output layer with a corresponding name. Output layer is
```
