PyTorch 1.12 and flash-attn==0.2.8 are not compatible. #27

Open
heheda12345 opened this issue Apr 10, 2024 · 0 comments
Comments

@heheda12345

Thanks for your great work! I am trying to reproduce the latency tests using the scripts in the Dejavu/benchmarks folder. I installed the recommended PyTorch 1.12 and flash-attn==0.2.8, but the two libraries are not compatible. I get the following error, caused by this line in flash-attn, which calls torch.distributed.get_global_rank. That function is not available in PyTorch 1.12; it exists only in newer PyTorch releases. Which library versions should I use to reproduce the results?

p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
AttributeError: module 'torch.distributed' has no attribute 'get_global_rank'
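
For anyone hitting the same error, here is a rough sketch of a possible compatibility shim. It is not a confirmed fix and not the maintainers' recommended setup. It assumes that PyTorch 1.12 still ships a private helper named `_get_global_rank` in `torch.distributed.distributed_c10d` that takes the same `(group, group_rank)` arguments; please verify that against your installed version. `resolve_get_global_rank` is a hypothetical helper name, and the stand-in modules below exist only so the sketch runs without torch installed.

```python
from types import SimpleNamespace


def resolve_get_global_rank(dist_module):
    """Return a callable mapping (process_group, group_rank) -> global rank.

    Prefers the public get_global_rank. Falls back to a private
    _get_global_rank if the module has one (ASSUMPTION: this is the
    case for PyTorch 1.12; check your install).
    """
    public = getattr(dist_module, "get_global_rank", None)
    if callable(public):
        return public
    private = getattr(dist_module, "_get_global_rank", None)
    if callable(private):
        return private
    raise AttributeError(
        "neither get_global_rank nor _get_global_rank found on "
        f"{getattr(dist_module, '__name__', dist_module)!r}"
    )


# Stand-in "modules" (hypothetical) that mimic an older and a newer API surface.
old_style = SimpleNamespace(_get_global_rank=lambda group, rank: group[rank])
new_style = SimpleNamespace(get_global_rank=lambda group, rank: group[rank])

ranks_in_group = [4, 5, 6, 7]  # hypothetical global ranks of one process group
print(resolve_get_global_rank(old_style)(ranks_in_group, 0))
print(resolve_get_global_rank(new_style)(ranks_in_group, 2))
```

With real torch, the idea would be to set `torch.distributed.get_global_rank = resolve_get_global_rank(torch.distributed.distributed_c10d)` before importing flash_attn. Whether that is enough to make flash-attn==0.2.8 work on PyTorch 1.12 is untested here.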

In addition, the scripts load a weight file called "full.pt", which is not in OPT's Hugging Face repo. How can I obtain this file?
