Commit

fix norm sharding
ZhiyuLi-goog committed May 2, 2024
1 parent fcf48fe commit bc36642
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion MaxText/layers/models.py
@@ -315,7 +315,7 @@ def __call__(
 weight_dtype=cfg.weight_dtype,
 name="decoder_norm",
 epsilon=cfg.normalization_layer_epsilon,
-kernel_axes=("embed",),
+kernel_axes=("norm",),
 )(y)
 y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic)
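
The change swaps the logical axis name on the `decoder_norm` scale parameter from `"embed"` to `"norm"`, so its sharding is resolved from the rule for `"norm"` rather than the rule for `"embed"`. As a rough illustration only, the sketch below resolves a logical axis name against a rule table. The rule table and the `resolve_axes` helper are hypothetical: they are not MaxText's actual configuration and not Flax's implementation, just a simplified first-match lookup.

```python
from typing import Optional, Sequence, Tuple

# Hypothetical (logical axis, mesh axis) rules, for illustration only.
# These are NOT the rules MaxText ships with.
LOGICAL_AXIS_RULES: Tuple[Tuple[str, Optional[str]], ...] = (
    ("embed", "fsdp"),
    ("norm", "tensor"),
)


def resolve_axes(
    logical_axes: Sequence[str],
    rules: Sequence[Tuple[str, Optional[str]]],
) -> Tuple[Optional[str], ...]:
    """Map each logical axis to the first matching mesh axis not already used.

    A logical axis with no matching rule resolves to None (replicated).
    """
    resolved = []
    for name in logical_axes:
        mesh_axis = None
        for rule_name, rule_mesh in rules:
            already_used = rule_mesh is not None and rule_mesh in resolved
            if rule_name == name and not already_used:
                mesh_axis = rule_mesh
                break
        resolved.append(mesh_axis)
    return tuple(resolved)


# Before the commit the norm scale used the "embed" rule; after, the "norm" rule.
before = resolve_axes(("embed",), LOGICAL_AXIS_RULES)
after = resolve_axes(("norm",), LOGICAL_AXIS_RULES)
print("before:", before, "after:", after)
```

Under these assumed rules the same parameter lands on a different mesh axis purely because its logical name changed, which is why a one-word edit to `kernel_axes` can alter how the layer is partitioned.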
