
Prevent cuda:0 context initialization when working on another cuda device #124722

Closed · wants to merge 1 commit from the inductor-multigpu-cuda-alloc branch

Conversation

@vfdev-5 (Collaborator) commented on Apr 23, 2024

Description

Issue description: when a user works on the "cuda:1" device and compiles a model, a CUDA context is also initialized on device "cuda:0", which can surprise the user who then sees device 0 being utilized in nvidia-smi.

Reproduction code:

import torch
from torchvision.models import resnet18

def print_memory_usage():
    # Sum the cumulative allocated, inactive (non-released) and reserved byte
    # counters reported by the CUDA caching allocator for devices 0 and 1.
    for d in [0, 1]:
        stats = torch.cuda.memory_stats(device=d)
        m = stats["allocated_bytes.all.allocated"] + stats["inactive_split_bytes.all.allocated"] + stats["reserved_bytes.all.allocated"]
        print(f"\t- CUDA Device: {d}, allocated + reserved + non-released in MB: {m / 1024 / 1024}")

device = "cuda:1"
model = resnet18()
compiled_model = torch.compile(model)

print("--- Before compiled model to device")
print_memory_usage()

compiled_model.to(device)
x = torch.rand(16, 3, 320, 320, device=device)

print("--- Before compiled model forward")
print_memory_usage()

y = compiled_model(x)

print("--- Before compiled model backward")
print_memory_usage()

y.sum().backward()

print("--- After compiled model backward")
print_memory_usage()

Output:

--- Before compiled model to device
        - CUDA Device: 0, allocated + reserved + non-released in MB: 0.0
        - CUDA Device: 1, allocated + reserved + non-released in MB: 0.0
--- Before compiled model forward
        - CUDA Device: 0, allocated + reserved + non-released in MB: 0.0
        - CUDA Device: 1, allocated + reserved + non-released in MB: 192.966796875
--- Before compiled model backward
        - CUDA Device: 0, allocated + reserved + non-released in MB: 8.044921875    # <--- this should be zero
        - CUDA Device: 1, allocated + reserved + non-released in MB: 2054.27197265625
--- After compiled model backward
        - CUDA Device: 0, allocated + reserved + non-released in MB: 8.044921875    # <--- this should be zero
        - CUDA Device: 1, allocated + reserved + non-released in MB: 5654.61962890625

This PR fixes the CUDA context initialization (init_cuda_context) performed during FakeTensor creation, lazy_init and pattern registrations, so that device cuda:0 no longer gets a context when the work runs on another device. With the fix, the same script prints:

--- Before compiled model to device
        - CUDA Device: 0, allocated + reserved + non-released in MB: 0.0
        - CUDA Device: 1, allocated + reserved + non-released in MB: 0.0
--- Before compiled model forward
        - CUDA Device: 0, allocated + reserved + non-released in MB: 0.0
        - CUDA Device: 1, allocated + reserved + non-released in MB: 192.966796875
--- Before compiled model backward
        - CUDA Device: 0, allocated + reserved + non-released in MB: 0.0
        - CUDA Device: 1, allocated + reserved + non-released in MB: 2054.31982421875
--- After compiled model backward
        - CUDA Device: 0, allocated + reserved + non-released in MB: 0.0
        - CUDA Device: 1, allocated + reserved + non-released in MB: 5654.66748046875
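
For illustration only, a minimal sketch of the idea (this is not the PR's actual diff, and touch_cuda_context is a hypothetical helper): derive the device set from the example inputs and initialize CUDA state under a device guard, so no context is created on cuda:0 as a side effect.

import torch

def touch_cuda_context(example_inputs):
    # Hypothetical helper: only initialize CUDA state on the devices that
    # actually appear in the inputs, never on the default device (cuda:0).
    devices = {t.device for t in example_inputs if isinstance(t, torch.Tensor) and t.is_cuda}
    for dev in devices:
        # torch.cuda.device(...) restores the previous current device on exit,
        # so only `dev`'s primary context is created here.
        with torch.cuda.device(dev):
            torch.zeros(1, device=dev)  # a small allocation forces context creation on `dev`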
  • Fix the issue
  • Add tests

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang


pytorch-bot bot commented Apr 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/124722

Note: Links to docs will display an error until the docs builds have been completed.

❌ 36 New Failures, 22 Unrelated Failures

As of commit 26a2ef7 with merge base bad8d25:


This comment was automatically generated by Dr. CI and updates every 15 minutes.

@vfdev-5 force-pushed the inductor-multigpu-cuda-alloc branch from 34abcd5 to 26a2ef7 on Apr 23, 2024 at 13:32
@@ -1094,7 +1094,13 @@ def fw_compiler_freezing(
    from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze

    # partition_fn won't be called
    _recursive_joint_graph_passes(aot_autograd_model)
    inputs_devices = list(
        {i.device for i in pytree.tree_flatten(aot_example_inputs)[0]}
    )
@vfdev-5 (Collaborator, Author) commented:

Here, I should avoid fetching the device on non-tensor inputs.
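
For example (an illustrative tweak, not the committed code), the comprehension could skip non-tensor leaves:

inputs_devices = list(
    {i.device for i in pytree.tree_flatten(aot_example_inputs)[0]
     if isinstance(i, torch.Tensor)}
)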


Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Jun 23, 2024
@github-actions github-actions bot closed this Jul 23, 2024