[RFC][C10D] Avoid creating new nccl communicator for each P2P pair #129140

Open
wconstab opened this issue Jun 20, 2024 · 0 comments
Labels: oncall: distributed, triaged

Comments

wconstab commented Jun 20, 2024

Motivation

Inside ProcessGroupNCCL::pointToPoint (the implementation behind all send/recv and batched send/recv operations), we have logic that creates (and caches) a unique NCCL communicator for each sender/receiver pair. This is unnecessary: the send/recv op could use the same communicator that is used for all collectives on this PG, and creating new communicators is expensive in both NCCL resources and runtime.

The reasons for this are: (1) the legacy behavior of PGNccl was to lazily initialize NCCL communicators on the first collective, so we cannot be sure the collective communicator for the PG has been created yet by the time of the first call to pointToPoint; and (2) pointToPoint does not require all ranks in the PG to call the API collectively (which we do require for any collective API), so we cannot lazily create the full collective communicator from within it.

Proposal

The proposal takes advantage of an existing feature called 'eager initialization'. When users call init_process_group, they can optionally pass a 'device_id' argument, which causes initialization to happen eagerly: the root communicator is created up front and can be used by the first operation, rather than the first operation having to create it. Additionally, any call to new_group eagerly calls comm_split if it is available in NCCL, or eagerly creates a new communicator otherwise. Hence, in this mode we can assume the 'PG communicator' exists for every PyTorch PG at the time of a call to pointToPoint.
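For reference, a minimal sketch of opting in from the Python side, using the device_id argument to init_process_group described above (the LOCAL_RANK environment variable assumes a torchrun-style launcher; exact behavior is version-dependent):

```python
import os
import torch
import torch.distributed as dist

# Passing device_id opts into eager initialization: the root NCCL
# communicator is created during init_process_group instead of lazily
# on the first collective, so later P2P ops can find and reuse it.
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{local_rank}")
dist.init_process_group(backend="nccl", device_id=device)

# new_group calls can then initialize eagerly as well (via comm_split
# when the NCCL version supports it).
# pg = dist.new_group(ranks=[0, 1])
```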

  1. If, inside pointToPoint, we can detect that the parent communicator has already been initialized, we use it preferentially.
  2. Otherwise, if users do not opt into eager PG initialization (i.e. do not pass device_id at init_process_group time), we preserve functionality for them by falling back to the existing behavior inside pointToPoint and creating new pair-wise communicators. We also issue a warning via WARN_ONCE suggesting they opt into eager initialization.

For (2), I think the easiest way to implement this is to modify ProcessGroupNCCL::getNCCLComm to accept an argument 'cached_only' which defaults to false. When set to true, getNCCLComm(cached_only=true) would return an existing communicator or nullopt. This would let pointToPoint issue the warning and then create the fallback pairwise communicator for the legacy case.
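To make the control flow concrete, here is a minimal Python sketch of the proposed lookup (the real change would live in C++ inside ProcessGroupNCCL; NCCLComm, the cache, and point_to_point here are illustrative stand-ins, not actual PyTorch APIs):

```python
import warnings
from typing import Dict, Optional


class NCCLComm:
    """Stand-in for the C++ NCCL communicator object."""
    def __init__(self, key: str) -> None:
        self.key = key


_comm_cache: Dict[str, NCCLComm] = {}


def get_nccl_comm(key: str, cached_only: bool = False) -> Optional[NCCLComm]:
    """Mirrors the proposed getNCCLComm(cached_only=...) behavior."""
    if key in _comm_cache:
        return _comm_cache[key]
    if cached_only:
        # Proposed new behavior: report a cache miss instead of creating.
        return None
    # Legacy path: create the communicator (expensive in reality) and cache it.
    comm = NCCLComm(key)
    _comm_cache[key] = comm
    return comm


def point_to_point(device_key: str, p2p_key: str) -> NCCLComm:
    # 1. Prefer the already-initialized PG ("parent") communicator.
    comm = get_nccl_comm(device_key, cached_only=True)
    if comm is None:
        # 2. Legacy fallback: warn, then build the pairwise communicator.
        warnings.warn("pass device_id to init_process_group so P2P ops can "
                      "reuse the PG communicator")
        comm = get_nccl_comm(p2p_key)
    return comm
```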

Amended Proposal to address comm stream and overlap

For P2P ops, we want to allow using a separate CUDA stream from the stream that collective ops (or P2P ops between other peers) run on, to facilitate overlap between those comm operations. This can be enabled by decoupling 'stream' and 'ncclComm', which today are both controlled by ProcessGroupNCCL and bundled together. The proposal here is to unbundle them.

  • getNCCLComm is a helper that takes a 'key' and uses it either to find a cache hit or to create a new entry; the entry includes an ncclComm, but also a stream and an event on that stream.
  • We can add a new keyword argument to getNCCLComm called 'streamKey' and let the caller decide whether to use the 'default key' (e.g. for collectives) or a special P2P key that is unique per P2P pair.
  • We amend the above proposal so that when a communicator already exists, we still use it (passing key=device, stream_key=p2p_ranks); if that doesn't find an existing communicator, we create a new one (passing key=p2p_ranks, stream_key=p2p_ranks). See the sketch after this list.
    • At all call sites where we issue collectives, we'll need to audit which key is used to access the stream cache and use the P2P one where appropriate.
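Extending the same illustrative Python sketch, decoupling the comm cache from the stream cache might look like this (again, all names are hypothetical stand-ins for the C++ internals, not PyTorch APIs):

```python
from typing import Dict, Optional, Tuple


class NCCLComm:
    """Stand-in for the C++ NCCL communicator."""
    def __init__(self, key: str) -> None:
        self.key = key


class CommStream:
    """Stand-in for a dedicated CUDA stream plus its event."""
    def __init__(self, key: str) -> None:
        self.key = key


_comm_cache: Dict[str, NCCLComm] = {}
_stream_cache: Dict[str, CommStream] = {}


def get_nccl_comm(
    key: str,
    stream_key: Optional[str] = None,
    cached_only: bool = False,
) -> Optional[Tuple[NCCLComm, CommStream]]:
    # The stream is cached under its own key, so a P2P pair can get a
    # dedicated stream while still sharing the PG-wide communicator.
    s_key = stream_key if stream_key is not None else key
    stream = _stream_cache.setdefault(s_key, CommStream(s_key))
    if key not in _comm_cache:
        if cached_only:
            return None
        _comm_cache[key] = NCCLComm(key)
    return _comm_cache[key], stream


# Amended P2P lookup: reuse the PG comm (key=device) on a pairwise stream.
hit = get_nccl_comm(key="cuda:0", stream_key="0:1", cached_only=True)
if hit is None:
    # Legacy fallback: pairwise comm, still on the pairwise stream.
    hit = get_nccl_comm(key="0:1", stream_key="0:1")
```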

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @yf225 @chauhang @d4l3k

@zou3519 added the oncall: distributed label Jun 21, 2024
@wz337 added the triaged label Jun 24, 2024
wconstab added a commit that referenced this issue Jun 25, 2024
Users that opt into eager initialization (enabled by passing device_id
to init_process_group) will now be able to reuse the process group's
existing communicator for send/recv ops, rather than creating new
2-rank communicators for every pair of ranks performing send/recv.

Existing users not passing device_id to init_process_group will now get
a warning suggesting they do so, but they will keep the functionality
they have today: automatic creation of pair-wise communicators.

When reusing an existing communicator, a dedicated nccl stream will
still be used for each pair of P2P ranks so that pair-wise comm ops can
overlap with each other rather than being serialized on a single stream
per PG.

Fixes #129140

ghstack-source-id: 3db38c68ea6a4947ef4a3f9fa61fc4865513f63c
Pull Request resolved: #129147