Skip to content

Commit

Permalink
changed all instances of torch.concat to torch.cat
Browse files · Browse the repository at this point in the history
  • Loading branch information
mariebiscuit committed Jul 17, 2023
1 parent 303d7be commit edfa39b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions megatron/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ def flash_attention(self, query_layer, key_layer, value_layer):
)

# Combined k/v into [b * sk, 2, np, hn].
kv = torch.concat([key_layer, value_layer], dim=1)
kv = torch.cat([key_layer, value_layer], dim=1)

output = self.flash_kv_fn(
query_layer,
Expand All @@ -553,7 +553,7 @@ def flash_attention(self, query_layer, key_layer, value_layer):
)

# Combined q/k/v into [b * s, 3, np, hn].
qkv = torch.concat([query_layer, key_layer, value_layer], dim=1)
qkv = torch.cat([query_layer, key_layer, value_layer], dim=1)

output = self.flash_qkv_fn(
qkv,
Expand Down

0 comments on commit edfa39b

Please sign in to comment.