Try Tensor Parallel on a server equipped with two V100 linked by NVLINK, but got a performance degradation #111

duanzhaol opened this issue Feb 27, 2024 · 8 comments


@duanzhaol

duanzhaol commented Feb 27, 2024

I'm attempting to deploy llama2-7b-chat-hf on a server equipped with two V100 GPUs linked by NVLink, but I've encountered an issue where the tokens per second (token/s) performance is worse when utilizing both GPUs compared to just one. Specifically, with a single GPU, I achieve 29 tokens/s, but this drops to only 17 tokens/s when employing tensor parallelism.
I analyzed the process using Nsight Systems, which revealed that the all-reduce operation averages 500µs. This duration significantly exceeds that of the compute operations. Here's a snapshot from the Nsight Systems analysis:
[Nsight Systems trace screenshot]

I have confirmed that NVLink is active, yet I'm puzzled by why the communication time is so prolonged. Could this be due to a configuration mistake? Below is a screenshot confirming NVLink's activation:
[NVLink status screenshot]
Furthermore, here is the command I used to run the program, explicitly setting NCCL_P2P_LEVEL=NVL so that NCCL uses peer-to-peer communication over NVLink. The NCCL log indicates that P2P is being used:
NCCL_P2P_LEVEL=NVL NCCL_DEBUG=TRACE OMP_NUM_THREADS=8 ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=2 generate.py --checkpoint_path ../llama-2-7b-chat-hf/model.pth --prompt "Hello, my name is" --num_samples 2
Could anyone provide insights into why the communication overhead is so substantial, and if there are any potential configuration errors causing this inefficiency?

@yifuwang
Contributor

From the information you provided:

  • The comm kernel is prefixed with ncclDevKernel, which suggests P2P is enabled
  • PyTorch's IntraNodeComm is not enabled (IIRC it only supports A100 and onward)

I suggest:

  • Checking nvidia-smi topo -m to see if the GPUs are directly connected
  • Running nsys with --gpu-metrics-device=all to collect the activities of both GPUs, then comparing them to see if it's a straggler problem
  • Running a minimal all_reduce microbenchmark to see if the issue still shows up (a sketch is included below)
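Something along these lines, launched with the same torchrun command, should be enough to measure the standalone all_reduce latency. This is only a rough sketch; the tensor size of 4096 is just an approximation of llama2-7b's hidden dimension, so adjust it to whatever shape you see in the trace.

```python
# Minimal all_reduce microbenchmark (sketch).
# Run with: torchrun --standalone --nproc_per_node=2 allreduce_bench.py
import os
import time

import torch
import torch.distributed as dist


def main():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Roughly a llama2-7b hidden-state slice; adjust to match the traced shape.
    x = torch.randn(1, 4096, device="cuda", dtype=torch.float16)

    # Warm up so NCCL communicators and kernels are initialized before timing.
    for _ in range(20):
        dist.all_reduce(x)
    torch.cuda.synchronize()

    iters = 200
    start = time.perf_counter()
    for _ in range(iters):
        dist.all_reduce(x)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    if dist.get_rank() == 0:
        print(f"avg all_reduce latency: {elapsed / iters * 1e6:.1f} us")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```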

@duanzhaol
Author

Thank you for your response. I have confirmed the presence of a straggler issue. As illustrated in the attached images, the first GPU remains idle, waiting for the second GPU during the AllReduce operation.
[Nsight Systems trace screenshot]
However, I am puzzled by the cause of this behavior. I have ensured that no other programs are running on my system that could potentially cause interference. This setup was established by simply cloning this repository and executing the provided code. Could there be a misconfiguration or another underlying issue that I might have overlooked?

@Chillee
Contributor

Chillee commented Feb 27, 2024

@duanzhaol I don't think you're using compilation are you?

@duanzhaol
Author

duanzhaol commented Feb 28, 2024

@duanzhaol I don't think you're using compilation are you?

That's right, I haven't used compile in my process. Is compilation a necessary step for tensor parallelism? I think it should work without it.

@Chillee
Contributor

Chillee commented Feb 28, 2024

Compilation will significantly reduce the tensor-parallel latency.

In general, gpt-fast will not be particularly fast without using compilation :P
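For reference, compilation is enabled through generate.py's --compile flag (assuming the flag documented in the repo's README), i.e. the same launch as above with --compile added:

NCCL_P2P_LEVEL=NVL OMP_NUM_THREADS=8 torchrun --standalone --nproc_per_node=2 generate.py --compile --checkpoint_path ../llama-2-7b-chat-hf/model.pth --prompt "Hello, my name is" --num_samples 2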

@duanzhaol
Author

I opted not to use compilation because my objective is to use tensor parallelism on a serverless platform. The initial compilation is significantly time-consuming, which becomes impractical in our case since each request necessitates a fresh compilation; that overhead is unacceptable for our use case. If there were a way to persist or checkpoint the results of the compilation (similar to checkpointing an engine in TensorRT), it would greatly improve efficiency. Unfortunately, I have yet to find a tool or method that provides this capability. Any guidance or suggestions on how to address this challenge would be immensely appreciated.
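For example, if Inductor's on-disk cache could be pointed at persistent storage shared across instances, something like the sketch below is roughly what I have in mind. I'm assuming the TORCHINDUCTOR_* environment variables here, whose availability and behavior depend on the PyTorch version, and I haven't verified that this actually avoids the cold-start compile:

```python
# Sketch: point torch.compile/Inductor's cache at persistent storage so a
# freshly started instance can reuse prior compilation artifacts.
# The cache path is hypothetical; env-var support depends on the PyTorch version.
import os

os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/mnt/shared/inductor-cache"  # hypothetical shared volume
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"  # opt into the FX graph cache, if available

import torch  # import after setting the variables so Inductor reads them
```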

@Chillee
Contributor

Chillee commented Feb 28, 2024

@duanzhaol Out of curiosity, what level of overhead is acceptable?

@duanzhaol
Author

@duanzhaol Out of curiosity, what level of overhead is acceptable?

Maybe less than a second? In serverless, if the function is purely stateless, every request needs to recompile the model. And even if it is optimized as a Model-as-a-Service platform, compilation will severely restrict the ability to scale out new instances to handle bursty workloads.

Moreover, I'm still puzzled by the significant straggler issue we're encountering without compilation. The kernel launch times, according to the nsys traces, show considerable variability. Does this problem originate from the implementation of gpt-fast, or is it a broader issue with using tensor parallelism in PyTorch without prior compilation?
