
Support for expandable segments with cuda graph trees #128068

Conversation

@bilal2vec (Contributor) commented Jun 5, 2024

This PR adds support for using expandable segments with private memory pools, which should unblock using them with CUDA graphs and CUDA graph trees. Currently, the allocator silently avoids using expandable segments when allocating in a private pool because checkpoint saving/restoring doesn't mesh well with how we keep track of unmapped blocks.
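
(For context, and not part of this PR's diff: expandable segments are opted into through the caching allocator config, e.g. the `PYTORCH_CUDA_ALLOC_CONF` environment variable. A minimal sketch:)

```python
import os

# Must be set before the caching allocator initializes (i.e. before the first
# CUDA allocation); setting it before importing torch is the safest option.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch

x = torch.empty(1024, 1024, device="cuda")  # served from an expandable segment
```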

The PR itself is pretty short; most of the logic for checkpointing and reapplying state for non-expandable segments transfers over without much work.

Expandable segments reserve a virtual address space equal in size to the amount of physical memory on the GPU. Every time we want to `malloc()` or `free()` memory in a memory pool with expandable segments turned on, we map or unmap pages of physical GPU memory under the hood to create a new block that we return to the caller. This is beneficial because each memory pool functions as a single segment of memory with a contiguous range of addresses that can grow and shrink as needed, avoiding the fragmentation that comes from allocating multiple non-contiguous segments that may not be mergeable.
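
As a rough illustration only (toy sizes and helper names invented here, not the allocator's actual code), the reserve-then-map idea looks like this:

```python
# Toy model of one expandable segment: a large reserved virtual range whose
# mapped frontier grows and shrinks in page-sized steps.
PAGE = 2 * 1024 * 1024      # mapping granularity assumed to be 2 MiB
RESERVED = 24 * 1024**3     # reserved virtual range, roughly the GPU's memory
mapped_bytes = 0            # how much of the range is physically backed

def round_up(n: int, multiple: int) -> int:
    return -(-n // multiple) * multiple

def grow(nbytes: int) -> None:
    """Map more physical pages at the end of the range (cuMemCreate/cuMemMap
    in the real allocator) so a block of `nbytes` can be handed out."""
    global mapped_bytes
    need = round_up(nbytes, PAGE)
    assert mapped_bytes + need <= RESERVED, "reserved address space exhausted"
    mapped_bytes += need

def shrink(nbytes: int) -> None:
    """Unmap pages at the end (cuMemUnmap/cuMemRelease in the real allocator),
    returning physical memory while keeping the virtual addresses reserved."""
    global mapped_bytes
    mapped_bytes -= round_up(nbytes, PAGE)
```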

The caching allocator handles this by creating, at init, an unmapped block for the entire reserved virtual address space, which is treated similarly to an unallocated block in a free pool. When callers call `malloc()`, it is split and mapped to create allocated blocks, and calling `free()` similarly caches and merges free blocks in a free pool for later reuse. Expandable blocks are unmapped and returned to CUDA when they are cleaned up, or when we hit an OOM and the allocator attempts to remap cached free blocks. The code paths to map, free, and unmap blocks in expandable segments are similar to those for normal blocks and do all the same work of updating memory-usage stats, moving blocks between active and free pools, and returning memory to CUDA.
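
A minimal sketch of that block bookkeeping, with a plain Python list standing in for the allocator's block list (all names invented for illustration):

```python
from dataclasses import dataclass

@dataclass
class ToyBlock:
    size: int
    mapped: bool             # backed by physical pages?
    allocated: bool = False  # currently handed out to a caller?

# At init the whole reserved range is one big unmapped block.
segment = [ToyBlock(size=8 << 30, mapped=False)]

def toy_malloc(segment, nbytes):
    tail = segment[-1]                      # trailing unmapped block
    assert not tail.mapped and tail.size >= nbytes
    tail.size -= nbytes                     # shrink the unmapped remainder
    blk = ToyBlock(size=nbytes, mapped=True, allocated=True)
    segment.insert(len(segment) - 1, blk)   # mapped block sits before the tail
    return blk

def toy_free(blk):
    # Cached (still mapped) in the free pool; the real allocator also merges
    # adjacent free blocks here.
    blk.allocated = False
```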

With CUDA graph trees and private memory pools, we need to be able to checkpoint the state of the memory allocator after each graph capture, and to reapply that state before capturing a new graph after replaying a captured one, so that the new capture sees the allocator state as it was after the replay and can reuse empty blocks and allocate new ones.
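
For example, the manual-capture analogue of what CUDA graph trees automates looks roughly like this (these are real `torch.cuda` graph APIs; combining the shared private pool with expandable segments is what this PR unblocks):

```python
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch

pool = torch.cuda.graph_pool_handle()        # one private pool shared by both graphs
static_in = torch.randn(64, 64, device="cuda")

# Warm up on a side stream before capturing (recommended for graph capture).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_in.relu().sigmoid()
torch.cuda.current_stream().wait_stream(s)

g1 = torch.cuda.CUDAGraph()
with torch.cuda.graph(g1, pool=pool):        # capture graph A into the private pool
    out1 = static_in.relu()

g1.replay()                                  # replay A ...

g2 = torch.cuda.CUDAGraph()
with torch.cuda.graph(g2, pool=pool):        # ... then capture graph B on top of A's
    out2 = out1.sigmoid()                    # pool; cudagraph trees checkpoints and
                                             # restores the pool's state around this
g2.replay()
```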

As mentioned in a comment below, memory in a private pool is cached until the private pool is destroyed, and allocations can only grow from additional graph captures; freeing any of this memory would result in invalid memory addresses and would break CUDA graphs.

One implementation detail to note for unmapped blocks with expandable segments is that they are tracked in the `unmapped` member variable of a `BlockPool`. `unmapped` is not part of the checkpointed state of the caching allocator and isn't restored when reapplying checkpoints: since we never free/unmap this memory back to CUDA, it persists across graph captures and replays.

Checkpointing the current state of the memory allocator works as expected with expandable segments. Checkpointing grabs the first block of every segment in the active and free pools of the private pool and traverses the segment's linked list of blocks to capture its state, which is then saved until it needs to be reapplied. For expandable segments, the last block in every segment is an unallocated, unmapped block containing the remaining amount of unmapped memory at graph-capture time, and this too is saved in the checkpoint.

Reapplying a checkpoint works by freeing all allocated blocks and merging them into a single block per segment; then, for each segment, we manually split and allocate all blocks from the checkpoint and free the blocks that the checkpoint marks as unallocated. For expandable segments, we make some modifications so that we never split unmapped blocks and avoid manually mapping and then freeing them.
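
A toy end-to-end sketch of that checkpoint/reapply flow, with dicts standing in for the allocator's block nodes (everything here is invented for illustration):

```python
# One segment = ordered list of block states; for an expandable segment the
# trailing entry is its unallocated, unmapped remainder.
def checkpoint_segment(segment):
    return [dict(block) for block in segment]

def reapply_checkpoint(live_segment, saved):
    restored = []
    for entry in saved:
        if not entry["mapped"]:
            # Never split or remap unmapped blocks: just check the live segment
            # still ends in an unmapped block and skip over it.
            tail = live_segment[-1]
            assert not tail["mapped"], "expected an unmapped block at the end"
            restored.append(tail)
        else:
            # Re-split/re-allocate mapped blocks from the checkpoint; blocks the
            # checkpoint marked unallocated are then returned to the free pool.
            restored.append(dict(entry))
    return restored

segment = [
    {"size": 512, "allocated": True, "mapped": True},
    {"size": 256, "allocated": False, "mapped": True},       # cached free block
    {"size": 7 << 20, "allocated": False, "mapped": False},  # unmapped remainder
]
saved = checkpoint_segment(segment)
restored = reapply_checkpoint(segment, saved)
```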

cc @mcarilli @ezyang @eellison @peterbell10 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

pytorch-bot commented Jun 5, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128068

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 5 Unrelated Failures

As of commit f2d0b30 with merge base 6f275ae:

NEW FAILURE - one new job has failed.

FLAKY - the other failing jobs were likely due to flakiness present on trunk.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

linux-foundation-easycla bot commented Jun 5, 2024

CLA Signed

The committers listed above are authorized under a signed CLA.

@bilal2vec bilal2vec marked this pull request as ready for review June 10, 2024 18:32
@bilal2vec bilal2vec requested review from eqy and a team as code owners June 10, 2024 18:32
@eqy added the module: cuda graphs, ciflow/inductor, ciflow/periodic, and ciflow/trunk labels Jun 10, 2024
@eqy (Collaborator) commented Jun 10, 2024

CC @zdevito @eellison for review

@eellison eellison self-requested a review June 10, 2024 20:20
@bdhirsh added the triaged label Jun 10, 2024
@zdevito (Contributor) left a comment


@eellison should check that the checkpoint restore state will work. I do not know enough about the requirements in that codepath. I added some minor comments inline.

Inline review comments on c10/cuda/CUDACachingAllocator.cpp (3) and test/test_cuda_expandable_segments.py (outdated, resolved)
@eellison (Contributor) left a comment


Hi, thank you for the PR! Would you mind giving a high-level description of how/why this works? For instance, I would say something along the following lines for the prior checkpointing:

PrivatePools contain segment allocations which are never released for the duration of the private pool. The allocations can only ever grow. If we capture its state and then continue to record, it is always valid to resume a prior state, because all of the previous allocations must still be contained within the private pool. Segments newly allocated since that point will be unallocated in the resumption.
comment here explains the rest of it.

I saw that you re-invoked some of the tests with expandable segments; that's great. Would it be possible to add a test where we re-map the expandable segment? I'm not sure this would occur today. I believe there's a way of artificially limiting memory that might help.
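
(One way to artificially limit memory in such a test might be `torch.cuda.set_per_process_memory_fraction`; a sketch only, with invented sizes, and whether it actually forces a remap depends on the device:)

```python
import torch

torch.cuda.set_per_process_memory_fraction(0.05)  # cap at ~5% of device memory
try:
    # Over-allocate until the cap is hit; before raising OOM the caching
    # allocator tries to release/remap its cached (expandable) blocks.
    bufs = [torch.empty(64 * 1024 * 1024, device="cuda") for _ in range(64)]
except torch.cuda.OutOfMemoryError:
    bufs = None
    torch.cuda.empty_cache()
finally:
    torch.cuda.set_per_process_memory_fraction(1.0)
```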

Also, just to confirm @zdevito:

If we have an initial graph recording with an expandable block allocated to: data_ptr_a

Then, in some future state, the block needs to be cuMemMap_'d. When we then checkpoint back to the first graph recording, will that data_ptr still be valid?

Inline review comments on c10/cuda/CUDACachingAllocator.cpp and test/test_cuda_expandable_segments.py (outdated, resolved)
@bilal2vec (Contributor, Author)

Thanks @eellison for the review! Will do. I'm trying to reason about the example you gave: IIUC, I'd expect remapping a block in a subsequent graph to break when replaying the first graph, right? Since that block's pointer would no longer be valid (for both expandable and non-expandable blocks).

IIUC remapping should only happen when you go down this code path and need to release blocks from the free pools (expandable or not) back to CUDA?

@eellison eellison requested a review from BoyuanFeng June 13, 2024 17:59
@eellison (Contributor) left a comment


Okay, I talked a bit with @zdevito, who will chime in.

IIUC remapping should only happen when you go down this code path and need to release blocks from the free pools (expandable or not) back to cuda?

Right - I mistakenly thought there was a single ExpandableSegment per device - it's actually per private pool. We should not unmap in this scenario.

Zach mentioned one nuance around the final block in Expandable Segments which might need to be handled in checkpointing.

Still need to fix the test errors, though.

Inline review comment on torch/_inductor/cudagraph_trees.py (resolved)
@eqy (Collaborator) commented Jun 14, 2024

Probably needs a rebase to retrigger CI

@bilal2vec (Contributor, Author) commented Jun 14, 2024

Zach mentioned one nuance around the final block in Expandable Segments which might need to be handled in checkpointing.

Yes, this is what the bulk of the issues I had with this were about, and I think the behavior should be correct now. Let me write a proper high-level overview for this, but let me know if I've missed something!

@bilal2vec force-pushed the bilal_cudagraphtrees_expandable_segments branch from 59318e7 to 5557f97 on June 14, 2024 20:34
@bilal2vec (Contributor, Author) commented Jun 14, 2024

Todo(bilal2vec): There is one missing part that I still need to fix: if you capture graph a -> graph b and both graphs map memory, then replay graph a and reapply the checkpoint state in preparation for capturing graph c, the reapplied checkpoint state will assume I have the same amount of unmapped memory as I did at the end of graph a, not the correct amount (the amount at the end of graph b), and should have an extra unallocated but mapped block of memory equal to the amount of memory mapped by graph b. Update: this is not an issue; graph breaks automatically return all freed blocks (expandable or otherwise) to CUDA.

@bilal2vec (Contributor, Author) commented Jun 18, 2024

Looks like some tests are failing because expandable segments are silently(?) unsupported on Windows: a check to disable expandable segments based on a CMake flag is disabled on Windows. Slightly related to #122057.

@zdevito (Contributor) left a comment


Unmapped blocks need to be treated a bit differently when restoring a segment: they cannot be changed just by restoring the segment's state, because they describe the state of the underlying CUDA memory. Changing where an unmapped block is would require restoring that state as well.

Instead, we should treat unmapped blocks as unchangeable when restoring a segment. The restore logic should check that they appear in the correct spot but otherwise skip over them.

Inline review comments on c10/cuda/CUDACachingAllocator.cpp (2, outdated, resolved)
@eellison eellison requested a review from zdevito June 25, 2024 15:53
@eqy (Collaborator) commented Jun 25, 2024

Probably needs a rebase for the inductor ImportError failures

@bilal2vec (Contributor, Author)

@pytorchmergebot rebase

pytorch-bot commented Jun 25, 2024

You don't have permissions to rebase this PR since you are a first time contributor. If you think this is a mistake, please contact PyTorch Dev Infra.

@bilal2vec force-pushed the bilal_cudagraphtrees_expandable_segments branch from 214ae2b to f2d0b30 on July 15, 2024 17:58
@bilal2vec (Contributor, Author)

@pytorchmergebot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

Merge failed

Reason: 1 job has failed: periodic / linux-focal-cuda11.8-py3.10-gcc9-debug / test (default, 2, 5, linux.4xlarge.nvidia.gpu)

Details for Dev Infra team: raised by workflow job.

@eqy (Collaborator) commented Jul 15, 2024

@pytorchmergebot merge -f "failure appears unrelated"

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024

Pull Request resolved: pytorch#128068
Approved by: https://github.com/zdevito, https://github.com/eqy
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
…h#128068)"

This reverts commit fdc8361.

Reverted pytorch#128068 on behalf of https://github.com/janeyx99 due to Reverting for breaking ROCm tests on trunk, I think the tests need to be qualified with @onlyCUDA ([comment](pytorch#128068 (comment)))
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024

Pull Request resolved: pytorch#128068
Approved by: https://github.com/eqy, https://github.com/eellison
Labels
ciflow/inductor, ciflow/periodic, ciflow/rocm, ciflow/trunk, Merged, module: cuda graphs, module: dynamo, module: inductor, open source, release notes: inductor, Reverted, topic: not user facing, triaged