RuntimeError: expected scalar type Float but found Half #170
Comments
It seems some bots left comments to spread a virus, I think. BE CAREFUL.
This error shouldn't occur: in the CLI demo, the entire model pipeline is loaded in FP16 by default, so there shouldn't be any FP32 in the path. Can you print the dtype of the pipeline?
Sure. I set breakpoints and confirmed the pipe dtype is fp16. Digging into the code, the error is raised inside the T5 model (also confirmed to be fp16): just after the "wo" module, hidden_states becomes fp32, and at the SECOND layer of the self-attention module (after a "wo" module) the layer-norm op raises the error (transformers.models.t5.modeling_t5:592).
This shouldn't be the case. I haven't encountered this situation so far, because this part of the replacement is reasonable and shouldn't directly cause errors.
I had the same problem in T5.
I met the same problem. How can I fix it, please?
This line is the root cause of the fp32 conversion, and it looks like it only affects fp16 due to the internal casting logic. A quick fix is to convert the dtype back after the "wo" module.
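One way to apply that kind of quick fix without editing library code is a forward hook that casts the module's output back down. A minimal sketch on a toy fp32 "wo" layer (not the actual transformers internals):

```python
import torch
import torch.nn as nn

wo = nn.Linear(8, 8).float()   # toy stand-in for the fp32-pinned "wo"

def cast_back_to_fp16(module, inputs, output):
    # Returning a tensor from a forward hook replaces the module's output,
    # undoing the silent fp32 promotion right after "wo".
    return output.to(torch.float16)

wo.register_forward_hook(cast_back_to_fp16)

x = torch.randn(2, 8, dtype=torch.float16)
y = wo(x.to(torch.float32))    # the library up-casts the input to match wo
assert y.dtype == torch.float16
```

On a real model, the hook would be registered on each T5 "wo" submodule (e.g. by iterating `named_modules()` and matching on the name).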
System Info / 系統信息
CUDA==11.8
pytorch==2.3.0
diffusers==0.30.1
transformers==4.44.2
apex==0.1
Information / 问题信息
Reproduction / 复现过程
run the demo code:
python inference/cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX-2b
I think it is related to the T5 model. "wo" is kept in fp32, so when hidden_states passes through the "wo" module it becomes fp32, which then causes the RuntimeError. The error can be avoided either by running the whole model in fp32, or by explicitly casting hidden_states back to fp16 after the "wo" module.
But I still wonder: is this a common bug, or is it just caused by a corrupted library dependency on my side?
Expected behavior / 期待表现
The model should run in fp16 mode without error.