🐛 fix RowParallelLinear bias checkpoint loading
vpj committed May 2, 2022
1 parent 318a676 commit cf93786
Showing 2 changed files with 18 additions and 3 deletions.
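
The change follows from how RowParallelLinear stores its bias across the two model-parallel partitions: per the new docstring in checkpoint.py, these biases are partitioned and get added during the reduce, so rebuilding a single-device parameter means summing the two shards instead of treating them as duplicate copies. A minimal sketch of that reasoning (hypothetical shapes and tensors, not code from this repository):

```python
import torch

d_in, d_out = 8, 4
x = torch.randn(1, d_in)

# Full row-parallel weight split along the input dimension; assume each shard
# also holds a share of the bias (half each, purely for illustration).
w = torch.randn(d_in, d_out)
b = torch.randn(d_out)
w1, w2 = w[:d_in // 2], w[d_in // 2:]
b1, b2 = 0.5 * b, 0.5 * b

# Each shard adds its bias share before the (simulated) all-reduce over outputs
y_parallel = (x[:, :d_in // 2] @ w1 + b1) + (x[:, d_in // 2:] @ w2 + b2)

# Single-device equivalent: concatenate the weight shards, sum the bias shares
y_merged = x @ torch.cat([w1, w2], dim=0) + (b1 + b2)

assert torch.allclose(y_parallel, y_merged, atol=1e-5)
```

This is why the biases of attention.dense and mlp.dense_4h_to_h in model.py switch from merge_params_duplicate to the new merge_params_sum.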
17 changes: 16 additions & 1 deletion src/neox/checkpoint.py
@@ -150,7 +150,22 @@ def merge_params_duplicate(param: Union[nn.Parameter, torch.Tensor], key: str, p
     diff = sum((w1 - w2) ** 2).item()
     assert diff < 1e-4, f'The partitions do not match: {key}'
 
-    param.data[:] = w1
+    param.data[:] = (w1 + w2) / 2.
+
+
+def merge_params_sum(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
+                     p2: Dict[str, torch.Tensor]):
+    """
+    ### Load biases that are partitioned and get added on reduce
+
+    :param param: is the parameter
+    :param key: is the name of the parameter
+    :param p1: first partition dictionary
+    :param p2: second partition dictionary
+    """
+    w1, w2 = p1[key], p2[key]
+
+    param.data[:] = w1 + w2
 
 
 #
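For reference, a toy invocation of the new merge_params_sum helper, assuming p1 and p2 are the two partition state dicts loaded from the checkpoint shards (the values below are made up):

```python
import torch
import torch.nn as nn

bias = nn.Parameter(torch.zeros(4))
p1 = {'attention.dense.bias': torch.tensor([1.0, 2.0, 3.0, 4.0])}
p2 = {'attention.dense.bias': torch.tensor([0.5, 0.5, 0.5, 0.5])}

# After this call, bias.data holds the element-wise sum of the two partitions' biases
merge_params_sum(bias, 'attention.dense.bias', p1, p2)
```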
4 changes: 2 additions & 2 deletions src/neox/model.py
@@ -363,7 +363,7 @@ def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
         """
         with monit.section('Load transformer layer'):
             # Attention output transform
-            checkpoint.merge_params_duplicate(self.attention.output.bias, 'attention.dense.bias', p1, p2)
+            checkpoint.merge_params_sum(self.attention.output.bias, 'attention.dense.bias', p1, p2)
             checkpoint.merge_params_dim_1(self.attention.output.weight, 'attention.dense.weight', p1, p2)
 
             # Attention query, key and value transform
@@ -379,7 +379,7 @@ def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
             checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight', p1, p2)
 
             # FFN first transform
-            checkpoint.merge_params_duplicate(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)
+            checkpoint.merge_params_sum(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)
             checkpoint.merge_params_dim_1(self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight', p1, p2)
 
             # Layer norm before FFN
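Finally, a sketch of how load_state might be driven when merging a two-way model-parallel checkpoint for single-device use; the shard paths and the layer variable are illustrative assumptions, not taken from this repository:

```python
import torch

# Two model-parallel shards of the same transformer layer (illustrative paths)
p1 = torch.load('shard_0/layer_01-model_states.pt', map_location='cpu')
p2 = torch.load('shard_1/layer_01-model_states.pt', map_location='cpu')

# layer is assumed to be a transformer layer module from src/neox/model.py;
# load_state merges the shards, summing the row-parallel biases as shown above.
layer.load_state(p1, p2)
```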
