
[Bug] Unexpected error when upgrading DGL version from 1.1.3 to 2.1.0 #7333

Closed
jalencato opened this issue Apr 20, 2024 · 15 comments · Fixed by #7409

@jalencato
🐛 Bug

When switching the DGL version from 1.1.3 to 2.1.0, I hit the following error:

2024-04-20T00:46:10.970Z	File "/graphstorm/python/graphstorm/model/embed.py", line 401, in forward
2024-04-20T00:46:10.970Z	emb = self.sparse_embeds[ntype](input_nodes[ntype], device)
2024-04-20T00:46:10.970Z	File "/usr/local/lib/python3.8/dist-packages/dgl/distributed/nn/pytorch/sparse_emb.py", line 112, in __call__
2024-04-20T00:46:10.970Z	emb = self._tensor[idx].to(device, non_blocking=True)
2024-04-20T00:46:10.970Z	File "/usr/local/lib/python3.8/dist-packages/dgl/distributed/dist_tensor.py", line 205, in __getitem__
2024-04-20T00:46:10.970Z	return self.kvstore.pull(name=self._name, id_tensor=idx)
2024-04-20T00:46:10.970Z	File "/usr/local/lib/python3.8/dist-packages/dgl/distributed/kvstore.py", line 1453, in pull
2024-04-20T00:46:10.970Z	part_id = self._part_policy[name].to_partid(id_tensor)
2024-04-20T00:46:10.970Z	File "/usr/local/lib/python3.8/dist-packages/dgl/distributed/graph_partition_book.py", line 1096, in to_partid
2024-04-20T00:46:10.970Z	return self._partition_book.nid2partid(id_tensor, self.type_name)
2024-04-20T00:46:10.970Z	File "/usr/local/lib/python3.8/dist-packages/dgl/distributed/graph_partition_book.py", line 789, in nid2partid
2024-04-20T00:46:10.971Z	nids = nids.numpy()
2024-04-20T00:46:10.971Z	TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

My code snippet looks like this:

self._sparse_embeds[ntype] = DistEmbedding(g.number_of_nodes(ntype),
                                           self.embed_size,
                                           embed_name + '_' + ntype,
                                           init_emb,
                                           part_policy=part_policy)
......

if len(input_nodes[ntype]) == 0:
    dtype = self.sparse_embeds[ntype].weight.dtype
    embs[ntype] = th.zeros((0, self.sparse_embeds[ntype].embedding_dim),
                           device=device, dtype=dtype)
    continue
emb = self.sparse_embeds[ntype](input_nodes[ntype], device)
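
A possible workaround on the GraphStorm side (a minimal sketch, assuming the node IDs may arrive on GPU from the sampler; this is not the upstream fix) is to move the ID tensor to CPU before the lookup, since the KVStore pull path calls Tensor.numpy() on it:

# Sketch only: keep the node-ID tensor on CPU for the DistEmbedding lookup;
# per the traceback, the lookup itself moves the result to `device`.
node_ids = input_nodes[ntype]
if node_ids.is_cuda:
    node_ids = node_ids.cpu()
emb = self.sparse_embeds[ntype](node_ids, device)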

To Reproduce

Steps to reproduce the behavior:

We are getting this error when using GraphStorm:

python3 /graphstorm/tools/gen_ogb_dataset.py --savepath /tmp/ogbn-arxiv-nc/ --retain-original-features true

python3 /graphstorm/tools/partition_graph.py --dataset ogbn-arxiv \
                                            --filepath /tmp/ogbn-arxiv-nc/ \
                                            --num-parts 1 \
                                            --num-trainers-per-machine 4 \
                                            --output /tmp/ogbn_arxiv_nc_train_val_1p_4t

python3 -m graphstorm.run.gs_node_classification \
       --workspace /tmp/ogbn-arxiv-nc \
       --num-trainers 1 \
       --num-servers 1 \
       --num-samplers 0 \
       --part-config /tmp/ogbn_arxiv_nc_train_val_1p_4t/ogbn-arxiv.json \
       --ip-config  /tmp/ip_list.txt \
       --ssh-port 22 \
       --cf /graphstorm/training_scripts/gsgnn_np/arxiv_nc.yaml \
       --save-perf-results-path /tmp/ogbn-arxiv-nc/models

Expected behavior

When running on DGL 1.1.3 I did not have this problem.

Environment

  • DGL Version (e.g., 1.0): 2.1
  • Backend Library & Version (e.g., PyTorch 0.4.1, MXNet/Gluon 1.3): Pytorch 2.1.0 + CUDA 12.1
  • OS (e.g., Linux): Linux
  • How you installed DGL (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.8
  • CUDA/cuDNN version (if applicable): 12
  • GPU models and configuration (e.g. V100): I am running on AWS G4 instance
  • Any other relevant information:

Additional context

@jalencato jalencato changed the title When switching to DGL 2.1.0 + Pytorch 2.1.0 on CUDA 12.0 [Bug] Unexpected error when upgrading DGL version from 1.1.3 to 2.1.0 Apr 20, 2024
@thvasilo
Contributor

@Rhett-Ying could you take a look here? I see TODOs listed in this code to replace the numpy operations with torch ones.
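
For illustration, the kind of change those TODOs point at could look roughly like this (a sketch only, not the actual DGL code; the real partition book tracks per-node-type ranges, which `part_start_ids` simplifies to a single sorted tensor of partition start IDs):

import torch as th

def nid2partid_sketch(nids, part_start_ids):
    # Sketch: map node IDs to partition IDs with torch.searchsorted instead of
    # numpy, moving GPU IDs to CPU once up front.
    nids = nids.cpu()
    return th.searchsorted(part_start_ids, nids, right=True) - 1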

@Rhett-Ying
Collaborator

I fixed the bug in a6505e8, but it is not merged into DGL 2.1. It is only in the master branch for now, so you could try the latest DGL nightly build.

This fix will land in the next release, DGL 2.2, which is planned for early May.

@Rhett-Ying Rhett-Ying self-assigned this Apr 23, 2024
@Rhett-Ying
Collaborator

The fix is included in the latest DGL 2.2.1. Please try it.

@thvasilo
Contributor

Hi @Rhett-Ying, I tried reproducing the example that @jalencato listed and got:

Traceback (most recent call last):
  File "/graphstorm/python/graphstorm/run/gsgnn_np/gsgnn_np.py", line 190, in <module>
    main(gs_args)
  File "/graphstorm/python/graphstorm/run/gsgnn_np/gsgnn_np.py", line 143, in main
    trainer.fit(train_loader=dataloader, val_loader=val_dataloader,
  File "/graphstorm/python/graphstorm/trainer/np_trainer.py", line 189, in fit
    self.optimizer.step()
  File "/graphstorm/python/graphstorm/model/gnn.py", line 119, in step
    optimizer.step()
  File "/opt/gs-venv/lib/python3.9/site-packages/dgl/distributed/optim/pytorch/sparse_optim.py", line 355, in step
    alltoall(
  File "/opt/gs-venv/lib/python3.9/site-packages/dgl/distributed/optim/pytorch/utils.py", line 88, in alltoall
    alltoall_cpu(
  File "/opt/gs-venv/lib/python3.9/site-packages/dgl/distributed/optim/pytorch/utils.py", line 26, in alltoall_cpu
    dist.scatter(
  File "/opt/gs-venv/lib/python3.9/site-packages/torch/distributed/c10d_logger.py", line 47, in wrapper
    return func(*args, **kwargs)
  File "/opt/gs-venv/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 3174, in scatter
    work = default_pg.scatter(output_tensors, input_tensors, opts)
RuntimeError: ProcessGroupGloo::scatter: invalid tensor type at index 0 (expected TensorOptions(dtype=long int, device=cuda:0, layout=Strided, requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt)), got TensorOptions(dtype=long int, device=cpu, layout=Strided, requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt)))

This is on dgl==2.2.1+cu121 and Torch 2.1 running inside the GraphStorm container on g5 instances.

@Rhett-Ying Rhett-Ying reopened this May 14, 2024
@Rhett-Ying
Collaborator

@thvasilo are you running on GPU with a backend other than nccl?

@thvasilo
Contributor

Correct

@Rhett-Ying
Collaborator

Why not use nccl as the backend? Does it work well?

@Rhett-Ying
Collaborator

This seems to be a new bug, and I don't know why it's triggered now.

@thvasilo
Contributor

I will try with nccl. So far we've only used gloo in GraphStorm, AFAIK. Is there a reason to avoid nccl, @classicsong?

@Rhett-Ying
Collaborator

Could you try to figure out why the tensors below end up on different devices? Both of them are supposed to be on CPU. This is the direct cause of the crash, right?

gather_list,
idx_split_size,
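
A quick way to check (assuming these are the variables in sparse_optim.py right before the collective, with gather_list a list of tensors and idx_split_size a tensor):

# Sketch: log the devices just before the alltoall to see which side is on GPU.
print("gather_list devices:", {t.device for t in gather_list})
print("idx_split_size device:", idx_split_size.device)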

@Rhett-Ying
Collaborator

Rhett-Ying commented May 14, 2024

DGL master (almost the same as 2.2.1) + torch 2.1.0+cu121

I tried to run https://github.com/dmlc/dgl/blob/master/examples/distributed/rgcn/node_classification.py with --num_gpus 4 --sparse-embedding --dgl-sparse --backend gloo, which utilizes dgl.distributed.optim.SparseAdam (the class that crashed in your case), and it works well with the gloo backend. Please note that the example uses nccl for GPU training, so manually modifying the code to use gloo is required.

If I use nccl, it crashes with the error below:

File "/home/ubuntu/workspace/dgl_1/python/dgl/distributed/optim/pytorch/utils.py", line 86, in alltoall                         
            th.distributed.all_to_all(output_tensor_list, input_tensor_list)th.distributed.all_to_all(output_tensor_list, input_te
nsor_list)th.distributed.all_to_all(output_tensor_list, input_tensor_list) 
No backend type associated with device type cpu

This seems to make sense, as we support alltoall_cpu() only?
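
For context, the gloo path emulates all_to_all with per-rank scatters on CPU tensors, roughly like this (a simplified sketch of the pattern in dgl/distributed/optim/pytorch/utils.py, not the exact code):

import torch as th
import torch.distributed as dist

def alltoall_cpu_sketch(rank, world_size, output_tensor_list, input_tensor_list):
    # gloo has no native all_to_all, so each rank takes a turn as the scatter
    # root; ProcessGroupGloo expects the tensors involved to be on CPU.
    input_tensor_list = [t.to(th.device("cpu")) for t in input_tensor_list]
    for i in range(world_size):
        dist.scatter(output_tensor_list[i], input_tensor_list if i == rank else [], src=i)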

@Rhett-Ying
Collaborator

I’ve reproduced it and found the commit to blame: 5dfaf99

This commit added device to gather_list and idx_split_size, but it didn’t account for alltoall supporting CPU tensors only when the backend is not nccl, and device is overridden by device = grads.device.

This change was merged after DGL 1.1.3, which is why we hit the issue in DGL 2.2.1.

In short, the direct cause is that these tensors were previously always on CPU, so it worked well with gloo. Now the tensors are on GPU, while the underlying alltoall call supports CPU tensors only when the backend is gloo.
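
The fix direction is to pick the device for those tensors based on the backend, roughly like this (a sketch of the idea, not the exact change in #7409):

import torch as th
import torch.distributed as dist

def collective_device(grads):
    # Sketch: keep the tensors that feed alltoall on CPU for gloo, and only
    # keep them on the gradient's (GPU) device when the backend is nccl.
    return grads.device if dist.get_backend() == "nccl" else th.device("cpu")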

@thvasilo
Contributor

Hi @Rhett-Ying, I ran the repro example that @jalencato posted with the code from #7409, and it works fine now. I think we can close this once that PR is merged.

@Rhett-Ying
Collaborator

@thvasilo could you run more examples to make sure that PR does not trigger any other issues?

@thvasilo
Contributor

If we merge this, we can run our automated integration tests with the daily pip build; I don't have the bandwidth to run more manual tests on the PR code.
