Skip to content

Commit

Permalink
Merge pull request #637 from google:lizhiyu/miss_sharding
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 630121986
  • Loading branch information
maxtext authors committed May 2, 2024
2 parents fcf48fe + bc36642 commit 2ac0af9
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion MaxText/layers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def __call__(
weight_dtype=cfg.weight_dtype,
name="decoder_norm",
epsilon=cfg.normalization_layer_epsilon,
kernel_axes=("embed",),
kernel_axes=("norm",),
)(y)
y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic)

Expand Down

0 comments on commit 2ac0af9

Please sign in to comment.