Update mask name for THD attention
cuichenx authored and ko3n1g committed Jun 28, 2024
1 parent 69d7d5b commit 16b9fdd
Showing 2 changed files with 35 additions and 17 deletions.
51 changes: 34 additions & 17 deletions megatron/core/transformer/custom_layers/transformer_engine.py
@@ -50,7 +50,10 @@ class TENorm:

     # TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm?
     def __new__(
-        cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5,
+        cls,
+        config: TransformerConfig,
+        hidden_size: int,
+        eps: float = 1e-5,
     ):
         if config.normalization == "LayerNorm":
             instance = te.pytorch.LayerNorm(
@@ -148,9 +151,9 @@ def __init__(
             fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
             tp_group=get_tensor_model_parallel_group(check_initialized=False),
             tp_size=self.config.tensor_model_parallel_size,
-            get_rng_state_tracker=get_cuda_rng_tracker
-            if get_cuda_rng_tracker().is_initialized()
-            else None,
+            get_rng_state_tracker=(
+                get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
+            ),
             init_method=condition_init_method(config, init_method),
             bias=bias,
             return_bias=self.te_return_bias,
@@ -258,9 +261,9 @@ def __init__(
             fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
             tp_group=get_tensor_model_parallel_group(check_initialized=False),
             tp_size=self.config.tensor_model_parallel_size,
-            get_rng_state_tracker=get_cuda_rng_tracker
-            if get_cuda_rng_tracker().is_initialized()
-            else None,
+            get_rng_state_tracker=(
+                get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
+            ),
             init_method=condition_init_method(config, init_method),
             bias=bias,
             return_bias=self.te_return_bias,
@@ -285,7 +288,7 @@ def forward(self, x):
         return out, None

     def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
-        """ Sharding along axis 0, bias sharded """
+        """Sharding along axis 0, bias sharded"""
         state_dict = self.state_dict(prefix='', keep_vars=True)
         return make_sharded_tensors_for_checkpoint(
             state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
@@ -331,7 +334,7 @@ def __init__(
         )

     def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
-        """ Sharding along axis 0, bias sharded """
+        """Sharding along axis 0, bias sharded"""
         state_dict = self.state_dict(prefix='', keep_vars=True)
         return make_sharded_tensors_for_checkpoint(
             state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
@@ -378,7 +381,7 @@ def __init__(
         )

     def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
-        """ Sharding along axis 1, bias not sharded """
+        """Sharding along axis 1, bias not sharded"""
         state_dict = self.state_dict(prefix='', keep_vars=True)
         return make_sharded_tensors_for_checkpoint(
             state_dict, prefix, {'weight': 1}, sharded_offsets
@@ -469,15 +472,15 @@ def __init__(
         super().__init__(
             num_attention_heads=self.config.num_attention_heads,
             kv_channels=self.config.kv_channels,
-            attention_dropout=self.config.attention_dropout
-            if attention_dropout is None
-            else attention_dropout,
+            attention_dropout=(
+                self.config.attention_dropout if attention_dropout is None else attention_dropout
+            ),
             attn_mask_type=attn_mask_type.name,
             sequence_parallel=self.config.sequence_parallel,
             tp_size=self.config.tensor_model_parallel_size,
-            get_rng_state_tracker=get_cuda_rng_tracker
-            if get_cuda_rng_tracker().is_initialized()
-            else None,
+            get_rng_state_tracker=(
+                get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
+            ),
             tp_group=get_tensor_model_parallel_group(check_initialized=False),
             layer_number=layer_number,
             **extra_kwargs,
@@ -519,6 +522,14 @@ def forward(
             value = value.as_strided(value.shape, key.stride())

         if self.te_forward_mask_type:
+            if qkv_format == 'thd' and _te_version >= packaging.version.Version("1.7.0"):
+                # thd format uses flash attention with cuDNN kernel which requires is_padding=True, so the only
+                # acceptable mask types are `padding_causal` and `padding`. These do not necessarily indicate
+                # there are padded tokens in the sequence.
+                if attn_mask_type == AttnMaskType.causal:
+                    attn_mask_type = AttnMaskType.padding_causal
+                elif attn_mask_type == AttnMaskType.no_mask:
+                    attn_mask_type = AttnMaskType.padding
             core_attn_out = super().forward(
                 query,
                 key,
@@ -528,7 +539,13 @@
                 **packed_seq_kwargs,
             )
         else:
-            core_attn_out = super().forward(query, key, value, attention_mask, **packed_seq_kwargs,)
+            core_attn_out = super().forward(
+                query,
+                key,
+                value,
+                attention_mask,
+                **packed_seq_kwargs,
+            )

         if self.config.apply_rope_fusion and qkv_format == 'bshd':
             return core_attn_out.transpose(0, 1)
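For reference, the mask remapping added in the hunk above can be exercised on its own. The following is a minimal, self-contained sketch: remap_mask_for_thd is a hypothetical helper written for illustration (not part of Megatron-LM), the AttnMaskType enum mirrors megatron/core/transformer/enums.py as of this commit, and the Transformer Engine version is passed in as a string instead of being read from the module-level _te_version.

import enum

from packaging import version


class AttnMaskType(enum.Enum):
    # Mirrors megatron/core/transformer/enums.py after this commit.
    padding = 1
    causal = 2
    no_mask = 3  # only used for TE
    padding_causal = 4  # only used for thd attention


def remap_mask_for_thd(attn_mask_type, qkv_format, te_version):
    """Promote mask types to their padded variants for the packed (thd) layout.

    The cuDNN/flash kernels used for thd sequences require is_padding=True, so
    `causal` becomes `padding_causal` and `no_mask` becomes `padding`; this does
    not imply that any tokens in the batch are actually padded.
    """
    if qkv_format == 'thd' and version.Version(te_version) >= version.Version("1.7.0"):
        if attn_mask_type == AttnMaskType.causal:
            return AttnMaskType.padding_causal
        if attn_mask_type == AttnMaskType.no_mask:
            return AttnMaskType.padding
    return attn_mask_type


# Example usage of the hypothetical helper:
assert remap_mask_for_thd(AttnMaskType.causal, 'thd', "1.7.0") is AttnMaskType.padding_causal
assert remap_mask_for_thd(AttnMaskType.causal, 'bshd', "1.7.0") is AttnMaskType.causal
assert remap_mask_for_thd(AttnMaskType.no_mask, 'thd', "1.6.0") is AttnMaskType.no_mask

With this mapping in place, callers can keep requesting causal or no_mask, and the thd path promotes them to the padded variants that the fused kernels accept.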
1 change: 1 addition & 0 deletions megatron/core/transformer/enums.py
@@ -24,3 +24,4 @@ class AttnMaskType(enum.Enum):
     padding = 1
     causal = 2
     no_mask = 3  # only used for TE
+    padding_causal = 4  # only used for thd attention
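As context for the new name, padding_causal targets the packed (thd) layout, in which variable-length sequences are concatenated along the token dimension and segment boundaries are tracked with cumulative sequence lengths. The snippet below is only an illustrative sketch of that layout in plain PyTorch; the tensor names and shapes are assumptions for illustration, not Megatron-LM or Transformer Engine code.

import torch

# Three variable-length sequences packed into one token dimension (thd layout).
seq_lens = torch.tensor([5, 3, 7])
cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)  # boundaries at [0, 5, 8, 15]

total_tokens, hidden_size = int(cu_seqlens[-1]), 16
packed_hidden_states = torch.randn(total_tokens, hidden_size)  # [t, h] instead of [b, s, h]

# The fused attention kernels use cu_seqlens to restrict attention to each segment,
# which is why the padding mask variants are required even when no token is padded.
print(cu_seqlens.tolist())  # [0, 5, 8, 15]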
