
Query device name from pytorch if only device index is given #500

Closed

Conversation

@dvrogozh commented Jul 15, 2024

Fixes: #499
Fixes: huggingface/transformers#31941

In some cases only a device index is given when querying a device. In this case both PyTorch and Safetensors returned 'cuda:N' by default. This causes runtime failures if the user actually runs on a non-CUDA device and does not have CUDA at all. This was recently addressed on the PyTorch side [1]: starting from PyTorch 2.5, calling 'torch.device(N)' returns the current device instead of a CUDA device.

This commit makes a similar change to Safetensors: if only a device index is given, Safetensors queries and returns the device by calling 'torch.device(N)'. This change is backward compatible, since the call returns 'cuda:N' on PyTorch <=2.4, which matches the previous Safetensors behavior.
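The behavior difference the PR relies on can be sketched as follows. This is an illustrative model, not the actual safetensors code: `normalize_device` and `resolve_index` are hypothetical names, and `resolve_index` stands in for `torch.device(N)` so the sketch runs without PyTorch installed.

```python
def normalize_device(device, resolve_index):
    """Normalize a device argument to a device string.

    `resolve_index` stands in for torch.device(N): on PyTorch <= 2.4 it
    resolves a bare index to 'cuda:N'; on PyTorch >= 2.5 it resolves to
    the current accelerator instead (e.g. 'xpu:N').
    """
    if isinstance(device, int):
        # Delegate to the framework so it decides which backend N maps to.
        return resolve_index(device)
    if isinstance(device, str):
        return device
    raise ValueError(f"unsupported device: {device!r}")

# Old behavior: a bare index always means CUDA.
print(normalize_device(0, lambda n: f"cuda:{n}"))  # cuda:0
# New (PyTorch >= 2.5) behavior: the framework resolves the index,
# here to an XPU device.
print(normalize_device(1, lambda n: f"xpu:{n}"))   # xpu:1
```

The backward compatibility claim follows directly: with an older PyTorch the resolver still yields 'cuda:N', so existing callers see no change.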

CC: @guangyey @jgong5 @faaany @muellerzr @SunMarc @Narsil

@dvrogozh (Author)

@Narsil: can you please help review?

@dvrogozh (Author)

Rebased on top of the latest main. @Narsil, @ArthurZucker, @SunMarc: can you please help review?

@dvrogozh (Author)

Resolved a conflict with 2331974. @Narsil @muellerzr @SunMarc: can this PR please be reviewed?

@Narsil (Collaborator) commented Jul 31, 2024

@dvrogozh Can you stop calling it a bug everywhere? It's not a bug; it's a breaking change you are proposing, which you introduced in torch==2.5.

The new behavior may be more user-friendly on non-CUDA accelerators, but it is nonetheless a breaking change and should be treated as such.

This PR introduces a dependency on torch itself, since we would rely on torch to produce the correct string; therefore I cannot take the code as-is.

The raison d'être of this code is to provide simple validation, so that safe_open(..., device="xxx") is rejected with an appropriate error message (before doing any work and sending invalid values to torch).
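The kind of torch-independent validation described here can be sketched in a few lines. This is a hypothetical illustration of the design goal, not the actual safetensors implementation (which is in Rust); the regex and the set of accepted backends are assumptions.

```python
import re

# Hypothetical pattern for well-formed device strings: a known backend
# name, optionally followed by ':<index>'. The backend list is illustrative.
_DEVICE_RE = re.compile(r"^(cpu|cuda|xpu|mps|npu|hpu|meta)(:\d+)?$")

def validate_device(device: str) -> str:
    # Reject malformed device strings up front, so the user gets a clear
    # error instead of a failure deep inside torch after work has started.
    if not _DEVICE_RE.match(device):
        raise ValueError(f"Invalid device string: {device!r}")
    return device
```

The point of doing this without torch is that the library can produce its own clear error message and does not have to track how torch's device-string parsing evolves across versions.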

@Narsil (Collaborator) commented Jul 31, 2024

I made a cleaner implementation, imho: #509

Can you check that it fixes your issue? You're also more than welcome to steal the code from that PR so we can merge your PR instead of mine, so you get credit.

Narsil and others added 2 commits July 31, 2024 07:43
Co-authored-by: Dmitry Rogozhkin <[email protected]>
Fixes: huggingface#499
Fixes: huggingface/transformers#31941

In some cases only a device index is given when querying a device.
In this case both PyTorch and Safetensors returned 'cuda:N' by
default. This causes runtime failures if the user actually runs on a
non-CUDA device and does not have CUDA at all. This was recently
addressed on the PyTorch side [1]: starting from PyTorch 2.5, calling
'torch.device(N)' returns the current device instead of a CUDA device.

This commit makes a similar change to Safetensors: if only a device
index is given, Safetensors queries and returns the device by calling
'torch.device(N)'. This change is backward compatible, since the call
returns 'cuda:N' on PyTorch <=2.4, which matches the previous
Safetensors behavior.

See [1]: pytorch/pytorch#129119

Signed-off-by: Dmitry Rogozhkin <[email protected]>
@dvrogozh (Author)
I made a cleaner implementation imho: #509. Can you check that it fixes your issue ?

Unfortunately, it does not. It fails with "RuntimeError: Invalid device string: '0'". See the PyTorch part of the stack at #509 (review).

@dvrogozh (Author)
@Narsil: I reworked my PR on top of your proposal in #509 to abstract device string parsing (adding fn parse_device). I preserved the original logic from my PR, which is to return the output of torch.device(N) instead of cuda:N. This should work for any PyTorch version, since on <2.5 torch.device(N) returns cuda:N and on >=2.5 it returns whatever is configured on the PyTorch side. The difference with #509 is that my PR returns the fully qualified PyTorch device name, while yours returns just the index. Apparently an index is not enough somewhere, but further debugging is needed to say exactly where (the PyTorch stack does not show the exact place; it's somewhere in safetensors' Rust code).
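The distinction being debated can be sketched as follows. All names here are illustrative (this is not the actual safetensors `parse_device`), and the `resolve` lambda stands in for `torch.device(N)` so the sketch runs without PyTorch.

```python
# Variant from this PR: a bare index is resolved by the framework into a
# fully qualified device string.
def parse_device_full(device, resolve=lambda n: f"cuda:{n}"):
    # With PyTorch >= 2.5, `resolve` may yield e.g. 'xpu:0' rather than
    # 'cuda:0', depending on the configured accelerator.
    if isinstance(device, int):
        return resolve(device)   # e.g. 'cuda:0'
    return device

# Variant from #509 (as described above): a bare index is passed through
# as just the index.
def parse_device_index(device):
    if isinstance(device, int):
        return str(device)       # e.g. '0'
    return device

print(parse_device_full(0))   # cuda:0
print(parse_device_index(0))  # 0  (the string that was later rejected
                              #     as "Invalid device string: '0'")
```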

@Narsil (Collaborator) commented Aug 1, 2024

The logic you kept is the logic I want to get rid of. It's still wrong to depend on torch internals (here, the string representation of the resolved device) in that specific part.

@dvrogozh (Author) commented Aug 1, 2024

The logic you kept is the logic I want to get rid of.

I am fine with this, as long as your change addresses the problem. I tried the modified version of #509, and it works for me now. We can proceed with your variant if you believe it's more aligned with the safetensors design.

@Narsil (Collaborator) commented Aug 1, 2024

Perfect, done.

Thanks again for raising awareness about the upcoming new behavior !

@Narsil (Collaborator) commented Aug 1, 2024

Superseded by #509

@Narsil closed this Aug 1, 2024