Use minimal softmax mask values based on the datatype (#966)
* added HF to NeoX 2.0 conversion script with mp and pp sharding

* (1) added missing curly brace to pythia/1-4B config; (2) fixed a bug related to a hardcoded value within the conversion script; (3) fixed possible bugs in the conversion script wrt the mp sharding convention

* fill in minimal possible mask values

* initialize tensor on the target device

---------

Co-authored-by: Quentin Anthony <[email protected]>
bentherien and Quentin-Anthony committed Jun 5, 2023
1 parent 7a595f5 commit a6065b4
Showing 1 changed file with 5 additions and 1 deletion.
megatron/model/gpt2_model.py: 5 additions & 1 deletion

@@ -45,7 +45,11 @@


 def gpt2_attention_mask_func(attention_scores, ltor_mask):
-    attention_scores.masked_fill_(ltor_mask, -10000.0)
+    mask_value = torch.finfo(attention_scores.dtype).min
+    # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+    # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+    mask_value = torch.tensor(mask_value, dtype=attention_scores.dtype, device=attention_scores.device)
+    attention_scores.masked_fill_(ltor_mask, mask_value)
     return attention_scores


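For context, here is a minimal standalone sketch (not part of the diff) of the pattern this commit adopts: instead of a hardcoded -10000.0, the mask value is taken from torch.finfo(attention_scores.dtype).min, materialized as a tensor with the scores' dtype and device, and applied with masked_fill_, so the masked logits are the most negative value representable in whatever precision the scores use. The helper name masked_scores and the toy shapes below are illustrative only.

    import torch

    def masked_scores(attention_scores, ltor_mask):
        # Same pattern as the patched gpt2_attention_mask_func: take the most
        # negative value representable in the scores' dtype, wrap it in a tensor
        # on the scores' device, and fill the masked positions with it.
        mask_value = torch.tensor(
            torch.finfo(attention_scores.dtype).min,
            dtype=attention_scores.dtype,
            device=attention_scores.device,
        )
        return attention_scores.masked_fill_(ltor_mask, mask_value)

    # Toy example in fp16: finfo(float16).min is -65504.0, not -10000.0,
    # so the fill value automatically matches the precision in use.
    scores = torch.zeros(1, 1, 4, 4, dtype=torch.float16)
    causal_mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)
    probs = torch.softmax(masked_scores(scores, causal_mask), dim=-1)
    print(probs[0, 0])  # upper-triangular (future) positions get ~0 probability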
