When using parallelize=True, RuntimeError is raised: expected all tensors to be on the same device #1575

Open
feiba54 opened this issue Mar 14, 2024 · 13 comments
Labels: bug (Something isn't working.)

feiba54 commented Mar 14, 2024

I think there are a few issues being conflated here and it would be helpful to disentangle them:

We support:

  • launching with accelerate launch, which is only meant to support data-parallel inference (no FSDP, no splitting a model across multiple GPUs);
  • launching without accelerate launch but with --model_args parallelize=True, which is meant to enable loading a single copy of the model, split across all the GPUs you have available.
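
Concretely, the two launch styles look something like this (the model and task names here are only placeholders):

Data-parallel (one full copy of the model per GPU):
accelerate launch -m lm_eval --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks winogrande

Model-parallel (one copy split across all visible GPUs):
lm_eval --model hf --model_args pretrained=EleutherAI/pythia-160m,parallelize=True --tasks winogrande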

For all the use cases you are describing, the latter option is what should be used. However, my understanding is that when trying to use parallelize=True, the result is RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

I'm struggling to reproduce this error, however (on the most recent version of the codebase); I will continue to see if I can find a way to replicate it.

Originally posted by @haileyschoelkopf in #1220 (comment)

feiba54 (Author) commented Mar 14, 2024

Hi, I think I am experiencing the same error with parallelize=True again. I tried both the 0.4.0 and 0.4.1 versions of the codebase and the result was the same.

My command is:

lm_eval --model "hf" --model_args pretrained=/home/xyf/workspace/pythia --task winogrande --batch_size 8 --num_fewshot 5 --verbosity "DEBUG" --output_path outputs --device cuda

I'm trying to use model parallelism to load a single copy of a model that is too big to fit on a single GPU. My GPU environment is 4× A800 (80 GB).
My result is:

  File "/home/yxf/anaconda3/envs/eval-vanilla/lib/python3.10/site-packages/torch/nn/functional.py", line 2237, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

This error occurred while running the loglikelihood requests.
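
For what it's worth, this class of error just means the embedding weight and the token-index tensor ended up on different GPUs. A stripped-down torch illustration (not the harness code; it assumes a machine with at least 4 GPUs):

import torch

emb = torch.nn.Embedding(10, 8).to("cuda:3")      # embedding layer placed on cuda:3, as a device map might do
ids = torch.tensor([[1, 2, 3]], device="cuda:0")  # input ids built on cuda:0
emb(ids)  # RuntimeError: Expected all tensors to be on the same device, ... cuda:3 and cuda:0!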

LSinev (Contributor) commented Mar 14, 2024

Usually Python reports more lines before the error; posting those additional lines would be more helpful.

By the way, Python 3.12.x may produce even friendlier errors when it comes to locating the last call made from the package.

feiba54 (Author) commented Mar 14, 2024

Hi, the full lines are as follows:

2024-03-14:14:19:48,223 INFO     [__main__.py:225] Verbosity set to DEBUG
2024-03-14:14:19:48,223 INFO     [__init__.py:373] lm_eval.tasks.initialize_tasks() is deprecated and no longer necessary. It will be removed in v0.4.2 release. TaskManager will instead be used.
2024-03-14:14:19:52,432 WARNING  [__main__.py:292] File outputs/results.json already exists. Results will be overwritten.
2024-03-14:14:19:52,432 INFO     [__main__.py:311] Selected Tasks: ['winogrande']
2024-03-14:14:19:52,432 INFO     [__main__.py:312] Loading selected tasks...
2024-03-14:14:19:52,434 INFO     [evaluator.py:129] Setting random seed to 0 | Setting numpy seed to 1234 | Setting torch manual seed to 1234
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
2024-03-14:14:19:53,777 INFO     [evaluator.py:190] get_task_dict has been updated to accept an optional argument, `task_manager`Read more here:https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md#external-library-usage
/home/fanyuxuan/anaconda3/envs/eval-vanilla/lib/python3.10/site-packages/datasets/load.py:1461: FutureWarning: The repository for winogrande contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/winogrande
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
  warnings.warn(
2024-03-14:14:19:59,675 WARNING  [evaluator.py:222] Overwriting default num_fewshot of winogrande from None to 5
2024-03-14:14:19:59,676 DEBUG    [cache.py:33] requests-winogrande is not cached, generating...
2024-03-14:14:19:59,676 INFO     [task.py:395] Building contexts for winogrande on rank 0...
100%|██████████████████████████████████████████████████████████████████████████| 1267/1267 [00:00<00:00, 23858.02it/s]
2024-03-14:14:19:59,760 DEBUG    [evaluator.py:327] Task: winogrande; number of requests on this rank: 2534
2024-03-14:14:19:59,760 INFO     [evaluator.py:357] Running loglikelihood requests
Running loglikelihood requests:   0%|                                                        | 0/2534 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/home/fanyuxuan/anaconda3/envs/eval-vanilla/bin/lm_eval", line 8, in <module>
    sys.exit(cli_evaluate())
  File "/data/fanyuxuan/lm-evaluation-harness/lm_eval/__main__.py", line 318, in cli_evaluate
    results = evaluator.simple_evaluate(
  File "/data/fanyuxuan/lm-evaluation-harness/lm_eval/utils.py", line 288, in _wrapper
    return fn(*args, **kwargs)
  File "/data/fanyuxuan/lm-evaluation-harness/lm_eval/evaluator.py", line 230, in simple_evaluate
    results = evaluate(
  File "/data/fanyuxuan/lm-evaluation-harness/lm_eval/utils.py", line 288, in _wrapper
    return fn(*args, **kwargs)
  File "/data/fanyuxuan/lm-evaluation-harness/lm_eval/evaluator.py", line 368, in evaluate
    resps = getattr(lm, reqtype)(cloned_reqs)
  File "/data/fanyuxuan/lm-evaluation-harness/lm_eval/api/model.py", line 323, in loglikelihood
    return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)
  File "/data/fanyuxuan/lm-evaluation-harness/lm_eval/models/huggingface.py", line 1029, in _loglikelihood_tokens
    self._model_call(batched_inps, **call_kwargs), dim=-1
  File "/data/fanyuxuan/lm-evaluation-harness/lm_eval/models/huggingface.py", line 744, in _model_call
    return self.model(inps).logits
  File "/home/fanyuxuan/anaconda3/envs/eval-vanilla/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/fanyuxuan/anaconda3/envs/eval-vanilla/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fanyuxuan/anaconda3/envs/eval-vanilla/lib/python3.10/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 1036, in forward
    outputs = self.gpt_neox(
  File "/home/fanyuxuan/anaconda3/envs/eval-vanilla/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/fanyuxuan/anaconda3/envs/eval-vanilla/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fanyuxuan/anaconda3/envs/eval-vanilla/lib/python3.10/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 897, in forward
    inputs_embeds = self.embed_in(input_ids)
  File "/home/fanyuxuan/anaconda3/envs/eval-vanilla/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/fanyuxuan/anaconda3/envs/eval-vanilla/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fanyuxuan/anaconda3/envs/eval-vanilla/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 163, in forward
    return F.embedding(
  File "/home/fanyuxuan/anaconda3/envs/eval-vanilla/lib/python3.10/site-packages/torch/nn/functional.py", line 2237, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)
Running loglikelihood requests:   0%|                                                        | 0/2534 [00:00<?, ?it/s]

haileyschoelkopf (Contributor) commented:

Hi @feiba54, have you tried the most recent version of the codebase from main?

feiba54 (Author) commented Mar 14, 2024

> Hi @feiba54, have you tried the most recent version of the codebase from main?

Yes, it's the 0.4.1 version, right? I git-cloned it yesterday.

haileyschoelkopf (Contributor) commented:

Could you try rerunning with a model that is public, and also see what happens if you remove --device cuda? I'm not sure what /home/xyf/workspace/pythia refers to.

haileyschoelkopf added the bug label Mar 14, 2024
feiba54 (Author) commented Mar 15, 2024

Hi, I tried with and without --device cuda and the results were the same.

Here is some additional information from using parallelize=True:

  • On A800 (80 GB) ×4, it seems that smaller models that fit on a single GPU (like EleutherAI/pythia-70m) raise this "expected all tensors to be on the same device" error, while larger models that have to be split across several GPUs (like Qwen/Qwen1.5-72B) work fine. I think the behavior should be consistent, because sometimes we need to test a series of models of different sizes together.
  • On V100 (16 GB) ×8 I tested 01-ai/Yi-34B; since the V100 does not support bfloat16, I converted to float32, but then I ran into an OOM error. For EleutherAI/pythia-70m, I still hit "expected all tensors to be on the same device".
2024-03-15:15:23:17,349 INFO     [evaluator.py:348] Running loglikelihood requests
Task_name: winogrande. Task: <abc.winograndeConfigurableTask object at 0x7efbd4f818d0>. Group: None. 
versions: defaultdict(<class 'dict'>, {'winogrande': 'Yaml'}). configs: defaultdict(<class 'dict'>, {'winogrande': {'task': 'winogrande', 'dataset_path': 'winogrande', 'dataset_name': 'winogrande_xl', 'training_split': 'train', 'validation_split': 'validation', 'doc_to_text': '<function doc_to_text at 0x7efbd6488820>', 'doc_to_target': '<function doc_to_target at 0x7efbd6488af0>', 'doc_to_choice': '<function doc_to_choice at 0x7efbd6488dc0>', 'description': '', 'target_delimiter': ' ', 'fewshot_delimiter': '\n\n', 'num_fewshot': 5, 'metric_list': [{'metric': 'acc', 'aggregation': 'mean', 'higher_is_better': True}], 'output_type': 'multiple_choice', 'repeats': 1, 'should_decontaminate': True, 'doc_to_decontamination_query': 'sentence', 'metadata': {'version': 1.0}}}). 

  0%|          | 0/2534 [00:00<?, ?it/s]
  0%|          | 1/2534 [00:08<5:50:32,  8.30s/it]
  0%|          | 2/2534 [00:12<4:19:55,  6.16s/it]Traceback (most recent call last):
  File "/home/aiscuser/.local/bin/lm_eval", line 8, in <module>
    sys.exit(cli_evaluate())
  File "/scratch/amlt_code/lm_eval/__main__.py", line 231, in cli_evaluate
    results = evaluator.simple_evaluate(
  File "/scratch/amlt_code/lm_eval/utils.py", line 415, in _wrapper
    return fn(*args, **kwargs)
  File "/scratch/amlt_code/lm_eval/evaluator.py", line 150, in simple_evaluate
    results = evaluate(
  File "/scratch/amlt_code/lm_eval/utils.py", line 415, in _wrapper
    return fn(*args, **kwargs)
  File "/scratch/amlt_code/lm_eval/evaluator.py", line 372, in evaluate
    resps = getattr(lm, reqtype)(cloned_reqs) # TODO:
  File "/scratch/amlt_code/lm_eval/models/huggingface.py", line 782, in loglikelihood
    return self._loglikelihood_tokens(new_reqs)
  File "/scratch/amlt_code/lm_eval/models/huggingface.py", line 996, in _loglikelihood_tokens
    self._model_call(batched_inps, **call_kwargs), dim=-1
  File "/scratch/amlt_code/lm_eval/models/huggingface.py", line 713, in _model_call
    return self.model(inps).logits
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aiscuser/.local/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/aiscuser/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1195, in forward
    logits = self.lm_head(hidden_states)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aiscuser/.local/lib/python3.10/site-packages/accelerate/hooks.py", line 161, in new_forward
    args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
  File "/home/aiscuser/.local/lib/python3.10/site-packages/accelerate/hooks.py", line 347, in pre_forward
    set_module_tensor_to_device(
  File "/home/aiscuser/.local/lib/python3.10/site-packages/accelerate/utils/modeling.py", line 387, in set_module_tensor_to_device
    new_value = value.to(device)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.71 GiB. GPU 0 has a total capacty of 15.77 GiB of which 965.12 MiB is free. Process 1140220 has 14.83 GiB memory in use. Of the allocated memory 12.16 GiB is allocated by PyTorch, and 1.71 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

  0%|          | 2/2534 [00:16<5:53:40,  8.38s/it]

haileyschoelkopf (Contributor) commented:

Thanks!

Can you clarify the commands you are running again, and also which transformers version is being used? I did not see parallelize=True in the command you pasted earlier, which is surprising to me.

RylanSchaeffer (Contributor) commented Mar 22, 2024

I have the same error when using parallelize=True:

Error: Command 'lm_eval --model hf --model_args pretrained=meta-llama/Llama-2-70b-hf,trust_remote_code=True,parallelize=True --tasks persona_believes-abortion-should-be-illegal --batch_size auto:4 --output_path eval_results/Llama2_70B_1.92T/persona_believes-abortion-should-be-illegal --log_samples' returned non-zero exit status 1.
Running Command:
lm_eval --model hf --model_args pretrained=meta-llama/Llama-2-70b-hf,trust_remote_code=True,parallelize=True --tasks swag --batch_size auto:4 --output_path eval_results/Llama2_70B_1.92T/swag --log_samples
2024-03-21:20:59:46,964 INFO     [__main__.py:225] Verbosity set to INFO
2024-03-21:20:59:46,964 INFO     [__init__.py:373] lm_eval.tasks.initialize_tasks() is deprecated and no longer necessary. It will be removed in v0.4.2 release. TaskManager will instead be used.
2024-03-21:20:59:50,732 INFO     [__main__.py:311] Selected Tasks: ['swag']
2024-03-21:20:59:50,733 INFO     [__main__.py:312] Loading selected tasks...
2024-03-21:20:59:50,738 INFO     [evaluator.py:129] Setting random seed to 0 | Setting numpy seed to 1234 | Setting torch manual seed to 1234
2024-03-21:20:59:50,867 WARNING  [logging.py:61] Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Loading checkpoint shards: 100%|██████████| 15/15 [00:01<00:00, 11.00it/s]
2024-03-21:20:59:53,796 INFO     [evaluator.py:190] get_task_dict has been updated to accept an optional argument, `task_manager`Read more here:https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md#external-library-usage
2024-03-21:20:59:59,125 INFO     [task.py:395] Building contexts for swag on rank 0...
100%|██████████| 20006/20006 [00:08<00:00, 2284.01it/s]
2024-03-21:21:00:08,922 INFO     [evaluator.py:357] Running loglikelihood requests
Running loglikelihood requests:   0%|          | 0/80024 [00:00<?, ?it/s]Passed argument batch_size = auto:4.0. Detecting largest batch size
Traceback (most recent call last):
  File "/lfs/skampere1/0/rschaef/miniconda3/envs/pred_llm_evals_env/bin/lm_eval", line 8, in <module>
    sys.exit(cli_evaluate())
  File "/lfs/skampere1/0/rschaef/KoyejoLab-Predictable-LLM-Evals/submodules/lm-evaluation-harness/lm_eval/__main__.py", line 318, in cli_evaluate
    results = evaluator.simple_evaluate(
  File "/lfs/skampere1/0/rschaef/KoyejoLab-Predictable-LLM-Evals/submodules/lm-evaluation-harness/lm_eval/utils.py", line 288, in _wrapper
    return fn(*args, **kwargs)
  File "/lfs/skampere1/0/rschaef/KoyejoLab-Predictable-LLM-Evals/submodules/lm-evaluation-harness/lm_eval/evaluator.py", line 230, in simple_evaluate
    results = evaluate(
  File "/lfs/skampere1/0/rschaef/KoyejoLab-Predictable-LLM-Evals/submodules/lm-evaluation-harness/lm_eval/utils.py", line 288, in _wrapper
    return fn(*args, **kwargs)
  File "/lfs/skampere1/0/rschaef/KoyejoLab-Predictable-LLM-Evals/submodules/lm-evaluation-harness/lm_eval/evaluator.py", line 368, in evaluate
    resps = getattr(lm, reqtype)(cloned_reqs)
  File "/lfs/skampere1/0/rschaef/KoyejoLab-Predictable-LLM-Evals/submodules/lm-evaluation-harness/lm_eval/api/model.py", line 323, in loglikelihood
    return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)
  File "/lfs/skampere1/0/rschaef/KoyejoLab-Predictable-LLM-Evals/submodules/lm-evaluation-harness/lm_eval/models/huggingface.py", line 933, in _loglikelihood_tokens
    for chunk in chunks:
  File "/lfs/skampere1/0/rschaef/KoyejoLab-Predictable-LLM-Evals/submodules/lm-evaluation-harness/lm_eval/models/utils.py", line 427, in get_batched
    yield from batch
  File "/lfs/skampere1/0/rschaef/KoyejoLab-Predictable-LLM-Evals/submodules/lm-evaluation-harness/lm_eval/models/utils.py", line 610, in get_chunks
    if len(arr) == (fn(i, _iter) if fn else n):
  File "/lfs/skampere1/0/rschaef/KoyejoLab-Predictable-LLM-Evals/submodules/lm-evaluation-harness/lm_eval/models/huggingface.py", line 866, in _batch_scheduler
    self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos)
  File "/lfs/skampere1/0/rschaef/KoyejoLab-Predictable-LLM-Evals/submodules/lm-evaluation-harness/lm_eval/models/huggingface.py", line 643, in _detect_batch_size
    batch_size = forward_batch()
  File "/lfs/skampere1/0/rschaef/miniconda3/envs/pred_llm_evals_env/lib/python3.10/site-packages/accelerate/utils/memory.py", line 136, in decorator
    return function(batch_size, *args, **kwargs)
  File "/lfs/skampere1/0/rschaef/KoyejoLab-Predictable-LLM-Evals/submodules/lm-evaluation-harness/lm_eval/models/huggingface.py", line 638, in forward_batch
    out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1)  # noqa: F841
  File "/lfs/skampere1/0/rschaef/KoyejoLab-Predictable-LLM-Evals/submodules/lm-evaluation-harness/lm_eval/models/huggingface.py", line 744, in _model_call
    return self.model(inps).logits
  File "/lfs/skampere1/0/rschaef/miniconda3/envs/pred_llm_evals_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/lfs/skampere1/0/rschaef/miniconda3/envs/pred_llm_evals_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/lfs/skampere1/0/rschaef/miniconda3/envs/pred_llm_evals_env/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1168, in forward
    outputs = self.model(
  File "/lfs/skampere1/0/rschaef/miniconda3/envs/pred_llm_evals_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/lfs/skampere1/0/rschaef/miniconda3/envs/pred_llm_evals_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/lfs/skampere1/0/rschaef/miniconda3/envs/pred_llm_evals_env/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 966, in forward
    inputs_embeds = self.embed_tokens(input_ids)
  File "/lfs/skampere1/0/rschaef/miniconda3/envs/pred_llm_evals_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/lfs/skampere1/0/rschaef/miniconda3/envs/pred_llm_evals_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/lfs/skampere1/0/rschaef/miniconda3/envs/pred_llm_evals_env/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
  File "/lfs/skampere1/0/rschaef/miniconda3/envs/pred_llm_evals_env/lib/python3.10/site-packages/torch/nn/functional.py", line 2233, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

haileyschoelkopf (Contributor) commented:

@feiba54 Confirming I've been able to reproduce this on my end now (with Pythia-70m). Investigating possible fixes!

haileyschoelkopf self-assigned this Mar 22, 2024
haileyschoelkopf (Contributor) commented:

A temporary workaround, if the error is

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cuda:0!

is to pass --device cuda:3 to avoid this issue. A TODO is to see whether HF exposes a way to get the device that the first layer / input embedding layer is on programmatically, without needing to know this attribute's name.
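
For reference, one way to look that device up with the public transformers API today (just a sketch of the idea, not necessarily how the harness will implement it):

import torch
from transformers import AutoModelForCausalLM

# load a model sharded across the visible GPUs via an auto-generated device map
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m", device_map="auto")

# the input embedding module records which device its weights landed on,
# so inputs can be moved there before calling forward
embed_device = model.get_input_embeddings().weight.device
input_ids = torch.tensor([[0, 1, 2]], device=embed_device)
logits = model(input_ids).logits

# model.hf_device_map also shows the full module -> device assignment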

jiaqiw09 (Contributor) commented May 6, 2024

> A temporary workaround, if the error is
>
> RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cuda:0!
>
> is to pass --device cuda:3 to avoid this issue. A TODO is to see whether HF exposes a way to get the device that the first layer / input embedding layer is on programmatically, without needing to know this attribute's name.

Hi, I just ran into the same problem. Here is an issue related to the device_map='auto' feature; I hope it helps:
huggingface/transformers#24410

monk1337 commented:

> Thanks!
>
> Can you clarify the commands you are running again, and also which transformers version is being used? I did not see parallelize=True in the command you pasted earlier, which is surprising to me.

As of now, setting the higher CUDA device works with parallelize=True. Is there a fix yet, or do I still need to set --device when using parallelize=True?
