
Commit

Add a docstring for YaFSDP class (#7)
antony-frolov authored Jun 25, 2024
1 parent 179dcec commit 71f3ed3
Showing 1 changed file with 62 additions and 0 deletions.
ya_fsdp/ya_fsdp.py: 62 additions & 0 deletions
@@ -26,6 +26,59 @@ class ReusableBufferViewState(Enum):


class YaFSDP(nn.Module):
"""A wrapper for sharding module parameters with YaFSDP.
Args:
module (nn.Module):
Module to be wrapped.
zero_stage (int):
Type of sharding to apply (analogous to `sharding_strategy` in
FSDP). `3` is for ZeRO-3 like sharding, 2 is for `ZeRO-2 like
sharding.
param_dtype (torch.dtype):
Dtype to use during forward pass.
modules_to_wrap_with_names (list[tuple[nn.Module, str]]):
Tuple of modules to shard with corresponding names for state_dict
(analogous to `auto_wrap_policy` in FSDP).
layer_norm_module_cls (Type[nn.Module]):
Class of layer norm layers.
gradient_accumulation_steps (int):
Number of gradient accumulation steps.
data_parallel_process_group (dist.ProcessGroup | None, optional):
Process group to shard parameters across.
Defaults to None.
intra_node_data_parallel_process_group (dist.ProcessGroup | None, optional):
A group of processes which share the node host and `data_parallel_process_group`.
Defaults to None.
model_parallel_process_group (dist.ProcessGroup | None, optional):
A group of tensor parallel processes. Defaults to None.
all_reduce_grads_across_model_parallel_group (bool, optional):
If True, layer norm parameters are reduced across `model_parallel_process_group`.
Defaults to False.
bit16_reduce_scatter (bool, optional):
If True, reduce scatter is performed in bfloat16.
Defaults to False.
bit32_acc_for_bit16_reduce_scatter (bool, optional):
If True, uses a custom kernel for float32 accuracy for bfloat16 reduce scatter. Defaults to False.
hpz_first_layers_num (int, optional):
Number for layers to apply HPZ to.
Defaults to 0.
output_layer_module_with_name (tuple[nn.Module, str] | None, optional):
Instance of output layer with a corresponding name. Output layer is
expected to contain only layer norm parameters.
Defaults to None.
sync_module_states (bool, optional):
If True, module states are synced across
`intra_node_data_parallel_process_group` before sharding.
Defaults to False.
param_init_fn (Callable | None, optional):
A function to initialize modules before sharding (and before syncing).
Defaults to None.
device_id (int | None, optional):
Device to use for initialization and sharding.
Defaults to None.
"""

    def __init__(
        self,
        module: nn.Module,
@@ -369,6 +422,7 @@ def clip_grad_norm_(self, max_norm: float | int, norm_type: float | int = 2.0) -

    def local_state_dict(self):
        state_dict = super().state_dict()
        state_dict = {key.removeprefix("_"): value for key, value in state_dict.items()}
        # Save meta on chief processes
        if self._data_parallel_process_group.rank() == 0:
            state_dict["meta"] = self._meta_info
@@ -388,6 +442,7 @@ def state_dict(self):

    def load_state_dict(self, state_dict):
        state_dict.pop("meta", None)
        state_dict = {f"_{key}": value for key, value in state_dict.items()}
        super().load_state_dict(state_dict)

    def get_files_to_load(self):
@@ -888,6 +943,13 @@ def __init__(

        self.register_full_backward_pre_hook(self.super_tensor.get_backward_pre_hook())

    def __getattr__(self, name: str) -> Any:
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self.root_module, name)
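
    # Illustrative note, not part of this diff: with the `__getattr__` above, an
    # attribute that exists only on the wrapped module resolves through the
    # wrapper, e.g. `wrapper.some_attr` (a hypothetical attribute) falls back to
    # `wrapper.root_module.some_attr` instead of raising AttributeError.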

    def pre_forward(self, args, kwargs):
        all_inputs = list(args) + [kwargs[key] for key in kwargs]
        gate_result = self.super_tensor.pre_forward(all_inputs)
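
For context, a minimal usage sketch based on the docstring added in this commit. It assumes the class is importable as `from ya_fsdp import YaFSDP`; the `build_model` helper, the `.layers` attribute, and the chosen hyperparameter values are illustrative assumptions, not part of the diff:

import torch
import torch.distributed as dist
import torch.nn as nn

from ya_fsdp import YaFSDP  # assumed import path

dist.init_process_group("nccl")

model = build_model()  # hypothetical helper returning a transformer nn.Module

# Shard each transformer block, pairing it with the name used in state_dict
# (assumes the model exposes its blocks as `model.layers`).
modules_to_wrap_with_names = [
    (layer, f"layers.{i}") for i, layer in enumerate(model.layers)
]

model = YaFSDP(
    model,
    zero_stage=3,                        # ZeRO-3-like sharding
    param_dtype=torch.bfloat16,          # dtype used during the forward pass
    modules_to_wrap_with_names=modules_to_wrap_with_names,
    layer_norm_module_cls=nn.LayerNorm,
    gradient_accumulation_steps=1,
    sync_module_states=True,
    device_id=torch.cuda.current_device(),
)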

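The `local_state_dict` / `load_state_dict` changes above strip and restore the leading underscore on keys and carry a `meta` entry on the chief data-parallel rank. A hedged sketch of the resulting round trip, assuming `model` is a YaFSDP instance and using illustrative file names:

import torch
import torch.distributed as dist

# Each data-parallel rank saves its own shard. `local_state_dict` strips the
# internal "_" prefix from keys; the chief rank of the data-parallel group
# additionally stores a "meta" entry.
shard = model.local_state_dict()
torch.save(shard, f"shard_{dist.get_rank()}.pt")

# `load_state_dict` drops a "meta" entry if present, restores the "_" prefix,
# and delegates to nn.Module.load_state_dict.
model.load_state_dict(torch.load(f"shard_{dist.get_rank()}.pt"))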
