
Add Mamba Architecture #1157

Merged 40 commits from mamba-neox into main on Mar 10, 2024
Conversation

@haileyschoelkopf (Contributor) commented Feb 26, 2024

closes #1148

This PR adds the Mamba architecture to NeoX, along with flags for toggling the selective-scan, conv1d, and full mamba_inner_fn fused kernels.

For now it does not support model parallelism, but I want to investigate adding tensor parallelism to this.
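For illustration, enabling this in a NeoX .yml config could look roughly like the sketch below. The Mamba block is selected through attention_config (see the discussion further down); the kernel-toggle names here are placeholders, not necessarily the exact flags added by this PR.

  # select the Mamba block for all 12 layers via attention_config
  "attention_config": [[["mamba"], 12]],

  # placeholder names for the per-kernel toggles described above
  "mamba_selective_scan_fusion": true,
  "mamba_causal_conv_fusion": true,
  "mamba_inner_func_fusion": true,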

@haileyschoelkopf marked this pull request as draft on February 26, 2024
@haileyschoelkopf (Contributor, Author)

This seems to train well without parallelism, but I'm hitting bugs in a conversion script I wrote (gibberish output). To diagnose further, I'll check for differences in output between a single instantiated layer from this implementation and the equivalent mamba_ssm module.
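For concreteness, that comparison could look roughly like the sketch below. The mamba_ssm usage follows its public API; the NeoX layer construction and weight copying (commented out) are placeholders for whatever this PR exposes.

import torch
from mamba_ssm import Mamba  # reference implementation

torch.manual_seed(0)
d_model = 768
x = torch.randn(2, 128, d_model, device="cuda")

# reference block from mamba_ssm
ref = Mamba(d_model=d_model, d_state=16, d_conv=4, expand=2).to("cuda")
out_ref = ref(x)

# neox_layer = ...                   # placeholder: instantiate the NeoX Mamba layer here
# neox_layer.load_state_dict(...)    # placeholder: copy ref's weights across
# out_neox = neox_layer(x)
# print(torch.allclose(out_neox, out_ref, atol=1e-5),
#       (out_neox - out_ref).abs().max())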

@haileyschoelkopf (Contributor, Author)

Worked after state-spaces/mamba#211!

Got performance from a 160M model (trained with the Pythia config, untied embed + unembed) on par with the Mamba-130m results in the paper! [image: evaluation results]

I'll clean up the code slightly and add sample configs, then mark this ready for review. This also pairs with a DeeperSpeed PR I'll make that should allow holding specified parameters in fp32 even when DeepSpeed tries to cast everything to 16-bit.

I also want to look into adding tensor parallelism for Mamba, but will do that later.
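As a conceptual sketch of the fp32-override idea (plain PyTorch, not the actual DeeperSpeed mechanism): the A_log and D names below come from the reference mamba_ssm module and are only an example of state-space parameters one might pin to fp32.

import torch

def cast_to_bf16_except(model, keep_fp32=("A_log", "D")):
    """Conceptual sketch: cast a model to bf16 but leave selected parameters in fp32.
    DeepSpeed does its casting inside the engine, so the real fix lives in DeeperSpeed;
    this only illustrates the end state we want for the SSM parameters."""
    model = model.to(torch.bfloat16)
    for name, param in model.named_parameters():
        if name.split(".")[-1] in keep_fp32:
            param.data = param.data.float()
    return model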

@Quentin-Anthony (Member)

Awesome! Great work. I have some TP ideas that we can discuss on Discord.

@haileyschoelkopf marked this pull request as ready for review on March 7, 2024
@haileyschoelkopf (Contributor, Author)

Ready for initial review!

This pairs with EleutherAI/DeeperSpeed#61; I'd appreciate feedback on whether the approach there is acceptable.

@Quentin-Anthony (Member)

Note for the future: in addition to attention_config, we should add a block_config where the user can choose, on a per-block basis, where to place individual blocks for MLP (and its variants like MoE), Mamba, and Attention (and its variants). The attention_config would then just let the user choose between attention variants for any attention blocks in the broader block_config. A similar mlp_config could be useful.

E.g.

"block_config": ["mamba", "attention", "mamba", "attention", "mlp"],
"attention_config": [[["flash"], 2]],
"mlp_config": [[["moe"], 1],

One drawback of this strategy is that it adds a lot of annoying bookkeeping for the user.

It's already a bit confusing that we're putting the attention-free Mamba block under attention_config, but I think that's fine for this PR.
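A hypothetical sketch of how a builder could consume such a block_config (none of these names exist in NeoX today; nn.Identity stands in for the real Mamba/attention/MLP blocks so the sketch runs as-is):

import torch.nn as nn

# Map each block_config entry to a constructor. The real classes would be NeoX's
# Mamba, attention, and MLP layers, with variants resolved by attention_config / mlp_config.
BLOCK_BUILDERS = {
    "mamba": lambda i: nn.Identity(),
    "attention": lambda i: nn.Identity(),
    "mlp": lambda i: nn.Identity(),
}

def build_layers(block_config):
    # e.g. block_config = ["mamba", "attention", "mamba", "attention", "mlp"]
    return nn.ModuleList(BLOCK_BUILDERS[kind](i) for i, kind in enumerate(block_config))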

@Quentin-Anthony (Member)

In the future, we should add support for the Triton RMSNorm kernel introduced by Mamba. Noting it here and adding a TODO for later.

https://github.com/state-spaces/mamba/blob/v1.2.0/mamba_ssm/ops/triton/layernorm.py
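For reference, RMSNorm itself is simple; the point of the Triton kernel is fusion and speed. A plain PyTorch version of what it computes:

import torch

class RMSNorm(torch.nn.Module):
    """Unfused reference: y = x / sqrt(mean(x^2) + eps) * weight (no mean-centering, no bias)."""
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms * self.weight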

@Quentin-Anthony merged commit 6809bbc into main on Mar 10, 2024
2 checks passed
@Quentin-Anthony deleted the mamba-neox branch on March 10, 2024