
Refine the logic of device construction when only device index is given #129119

Closed
36 commits

Conversation

guangyey
Collaborator

@guangyey guangyey commented Jun 20, 2024

Stack from ghstack (oldest at bottom):

# Motivation

Before this PR, device construction defaulted to the `cuda` type when only a device index was given (it returns the `PrivateUse1` type instead if a `PrivateUse1` backend is registered).

```bash
>>> import torch
>>> device = torch.device(0)
>>> device.type
'cuda'
>>> a = torch.tensor([1, 2])
>>> b = a.to(0)
>>> b
tensor([1, 2], device='cuda:0')
```

This works well on a CUDA GPU, but it raises a misleading error when running on XPU:

```bash
>>> import torch
>>> device = torch.device(0)
>>> device.type
'cuda'
>>> a = torch.tensor([1, 2])
>>> b = a.to(0)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/xxx/pytorch/torch/cuda/__init__.py", line 302, in _lazy_init
    raise AssertionError("Torch not compiled with CUDA enabled")
AssertionError: Torch not compiled with CUDA enabled
```

With this PR, the logic is refined to use the currently available device type instead.
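The intended resolution order can be sketched in plain Python. This is only an illustration, not the actual C++ implementation: the flag arguments are stand-ins for the real build and runtime checks, and the exact precedence should be confirmed against the code.

```python
def resolve_device_type(cuda_available: bool, xpu_available: bool,
                        privateuse1_registered: bool) -> str:
    """Pick a device type for torch.device(index) when only an index is given.

    A registered PrivateUse1 backend takes precedence; otherwise fall back
    to whichever accelerator this build supports, and fail loudly if none.
    """
    if privateuse1_registered:
        return "privateuseone"
    if cuda_available:
        return "cuda"
    if xpu_available:
        return "xpu"
    raise RuntimeError(
        "Cannot construct a device from a bare index: no accelerator available."
    )

# An XPU-only build now resolves a bare index to the xpu type.
print(resolve_device_type(cuda_available=False, xpu_available=True,
                          privateuse1_registered=False))  # -> xpu
```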


pytorch-bot bot commented Jun 20, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/129119

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit c1fdb82 with merge base e2e624a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@@ -813,8 +813,14 @@ inline at::Device toDevice(PyObject* obj) {
c10::DeviceType::PrivateUse1,
static_cast<c10::DeviceIndex>(device_index));
}
#ifdef USE_CUDA
Collaborator

Updated #126646 with details.
Can you change this to get the current accelerator instead?

Collaborator Author

@guangyey guangyey commented Jun 21, 2024

OK, using the current accelerator looks better.
I added XPU to getAccelerator and refined the code. To facilitate review, I separated the change into two PRs.

@guangyey guangyey requested a review from albanD June 21, 2024 03:47
guangyey added a commit that referenced this pull request Jun 21, 2024
ghstack-source-id: a34c42c3cbaf5e492e6baa794251b4bc1dc74c10
Pull Request resolved: #129119
@guangyey guangyey added ciflow/trunk Trigger trunk jobs on your pull request release notes: python_frontend release notes category labels Jun 21, 2024
@gujinghui gujinghui requested a review from EikanWang June 21, 2024 06:15
…ndex is given"


# Motivation
Before this PR, device construction defaulted to the `cuda` type when only a device index was given (it returns the `PrivateUse1` type instead if a `PrivateUse1` backend is registered).
```bash
>>> import torch
>>> device = torch.device(0)
>>> device.type
'cuda'
>>> a = torch.tensor([1, 2])
>>> b = a.to(0)
>>> b
tensor([1, 2], device='cuda:0')
```
This works well on a CUDA GPU, but it raises a misleading error when running on XPU:
```bash
>>> import torch
>>> device = torch.device(0)
>>> device.type
'cuda'
>>> a = torch.tensor([1, 2])
>>> b = a.to(0)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/xxx/pytorch/torch/cuda/__init__.py", line 302, in _lazy_init
    raise AssertionError("Torch not compiled with CUDA enabled")
AssertionError: Torch not compiled with CUDA enabled
```
With this PR, the logic is refined for XPU: it returns the `xpu` device type if PyTorch is built with XPU support, and raises an error if PyTorch is built without any accelerator but is given only a device index. Now it works as expected on XPU:
```bash
>>> import torch
>>> device = torch.device(0)
>>> device.type
'xpu'
>>> a = torch.tensor([1, 2])
>>> b = a.to(0)
>>> b
tensor([1, 2], device='xpu:0')
```



[ghstack-poisoned]
@@ -213,7 +213,8 @@ non-None device argument. To globally change the default device, see also

.. note::
For legacy reasons, a device can be constructed via a single device ordinal, which is treated
as a cuda device. This matches :meth:`Tensor.get_device`, which returns an ordinal for cuda
as a currently available device type (i.e. "cuda" if cuda is available, "xpu" if xpu is available).
Collaborator

Given that we're going to have a few features relying on it as the default (pin memory for example), I think we should make "accelerator" a public concept in this doc.
I think we can:

  • Have one paragraph in https://pytorch.org/docs/stable/torch.html that introduces the concept of accelerator and lists all the current ones (based on the list in c10).
  • Have other docs like this one just mention the "current accelerator" with a link to the paragraph above.

WDYT?

Collaborator

Sounds good. Maybe we can take the RFC below together with this PR, to figure out an accelerator mechanism for both the frontend and the backend? #128403

@guangyey let's have a talk next week.

Collaborator Author

That sounds great to me.

OnlyFor pushed a commit to OnlyFor/pytorch that referenced this pull request Jul 2, 2024
ghstack-source-id: 665d24a5bcbf922c1125c4fa753066bedbe1968e
Pull Request resolved: pytorch#129119
guangyey added a commit that referenced this pull request Jul 4, 2024
ghstack-source-id: c5b3829bd33f8b32a0761b9d850630028d7c1d01
Pull Request resolved: #129119
guangyey added a commit that referenced this pull request Jul 4, 2024
ghstack-source-id: 2d88b54f74342a330c1bee53d2dba94428582bbf
Pull Request resolved: #129119
@dvrogozh
Contributor

Folks, just a heads up: there was an assumption that this PR would address huggingface/transformers#31941. Unfortunately, it does not fully solve it. I believe this PR provides an essential change, but it alone is not enough to get the described case working. That said, at the moment I don't know the remainder of the root cause for 31941; it may be that HF needs another fix somewhere, or PyTorch does. I can't say right now.

@gujinghui
Collaborator

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@dvrogozh
Contributor

> Folks, just a heads up: there was an assumption that this PR would address huggingface/transformers#31941. Unfortunately, it does not fully solve it.

Update on this: the root cause is a bug in the HF safetensors library, see huggingface/safetensors#499. This PR indeed provides the required prerequisite for the fix. So, thank you for merging, @gujinghui.

dvrogozh added a commit to dvrogozh/safetensors that referenced this pull request Jul 15, 2024
Fixes: huggingface#499
Fixes: huggingface/transformers#31941

In some cases only a device index is given when querying a device. In this
case both PyTorch and Safetensors were returning 'cuda:N' by default.
This causes runtime failures if the user actually runs on a non-CUDA
device and does not have CUDA at all. Recently this was addressed on the
PyTorch side by [1]: starting from PyTorch 2.5, calling 'torch.device(N)'
returns the current device instead of a CUDA device.

This commit makes a similar change to Safetensors. If only a device
index is given, Safetensors will query and return the device by calling
'torch.device(N)'. This change is backward compatible, since this call
returns 'cuda:N' on PyTorch <= 2.4, which aligns with previous
Safetensors behavior.

See[1]: pytorch/pytorch#129119
Signed-off-by: Dmitry Rogozhkin <[email protected]>
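The backward-compatibility argument in this commit message can be sketched as follows. This is a hypothetical model, not Safetensors code: `device_from_index`, the version tuples, and the `accelerator` parameter are illustrative stand-ins for what `torch.device(N)` resolves to on each PyTorch version.

```python
def device_from_index(index: int, torch_version: tuple, accelerator: str) -> str:
    """Model what torch.device(index) resolves to for a bare device index.

    On PyTorch <= 2.4 a bare index was always treated as a CUDA ordinal;
    from 2.5 (after this PR) it follows the current accelerator. Delegating
    to torch.device(N) therefore preserves the old behavior on old PyTorch
    and picks up the fix on new PyTorch.
    """
    if torch_version <= (2, 4):
        return f"cuda:{index}"
    return f"{accelerator}:{index}"

# Old PyTorch keeps the historical 'cuda:N' interpretation...
print(device_from_index(0, (2, 4), accelerator="xpu"))  # cuda:0
# ...while 2.5+ resolves to the accelerator that is actually available.
print(device_from_index(0, (2, 5), accelerator="xpu"))  # xpu:0
```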
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
…en (pytorch#129119)

Pull Request resolved: pytorch#129119
Approved by: https://github.com/albanD, https://github.com/gujinghui, https://github.com/EikanWang
ghstack dependencies: pytorch#129463, pytorch#129205, pytorch#129363
dvrogozh added a commit to dvrogozh/safetensors that referenced this pull request Jul 25, 2024
dvrogozh added a commit to dvrogozh/safetensors that referenced this pull request Jul 30, 2024
dvrogozh added a commit to dvrogozh/safetensors that referenced this pull request Jul 31, 2024
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: python_frontend release notes category
Projects
Status: Done

7 participants