Allow for preventing 16-bit cast of marked modules #61

Merged
merged 1 commit into from
Mar 8, 2024

haileyschoelkopf
This PR lets you mark a model parameter with p._deepspeed_no_cast = True so that DeepSpeed does not cast it to 16-bit.

I used this to train a mamba-160m model, keeping the A_log and D parameters in fp32 throughout; it performed on par with the paper's reported results. That said, if this change seems like a bad idea or is likely to break things, I'd be glad to know what to check for.

This should also allow us to keep inv_freq in rotary embeddings in fp32!
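
A minimal sketch of how the flag might be used, not the code in this PR. Only the attribute name `_deepspeed_no_cast` comes from the description above; the helpers `mark_no_cast` and `partition_for_cast`, the parameter names, and the `SimpleNamespace` stand-ins are hypothetical and chosen so the example runs without PyTorch or DeepSpeed installed. With a real model you would presumably pass `model.named_parameters()` instead of the stand-in list.

```python
from types import SimpleNamespace

# Attribute name taken from the PR description; everything else is illustrative.
NO_CAST_FLAG = "_deepspeed_no_cast"


def mark_no_cast(named_params, suffixes):
    """Set the flag on every parameter whose name ends with one of `suffixes`."""
    marked = []
    for name, param in named_params:
        if name.endswith(tuple(suffixes)):
            setattr(param, NO_CAST_FLAG, True)
            marked.append(name)
    return marked


def partition_for_cast(named_params):
    """Split names into (would be cast to 16-bit, left in original dtype)."""
    to_cast, to_keep = [], []
    for name, param in named_params:
        if getattr(param, NO_CAST_FLAG, False):
            to_keep.append(name)
        else:
            to_cast.append(name)
    return to_cast, to_keep


# Stand-ins for (name, parameter) pairs; the names are hypothetical.
params = [
    ("layers.0.mixer.A_log", SimpleNamespace()),
    ("layers.0.mixer.D", SimpleNamespace()),
    ("layers.0.mixer.in_proj.weight", SimpleNamespace()),
]

marked = mark_no_cast(params, suffixes=("A_log", ".D"))
to_cast, to_keep = partition_for_cast(params)
```

Reading the flag with `getattr(param, NO_CAST_FLAG, False)` means unmarked parameters keep the default cast behaviour, so only parameters that were explicitly marked are affected.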

Member

Quentin-Anthony left a comment

This is acceptable. I'll have to think about how we merge this into upstream DeepSpeed.

Quentin-Anthony merged commit 6d097be into main on Mar 8, 2024
9 of 17 checks passed
2 participants