This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
Improve the speed of the pointwise fusion graph pass #17114
Merged
Description
Fixes #17105
This PR fixes 2 problems that together led to the huge time spent doing fusion reported in #17105.
The first problem was the `get_subsets` function, a part of the pointwise fusion graph pass. It involves a step that ensures there are no cycles created in a graph after the fusion step. It does so in 2 steps:
- it produces a separation set for each incompatible node
- it looks at those separations and creates a set of nodes that are incompatible to be in the same fusion with a given node `n`
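To illustrate why fusing nodes can create a cycle in the first place, here is a minimal C++ sketch. It is not the `get_subsets` implementation: it assumes a graph stored as adjacency lists and a proposed fusion group given as a boolean mask, and `FusionCreatesCycle`, `adj` and `in_group` are hypothetical names. It checks whether any path leaves the group and later re-enters it, which is the situation the step above has to rule out.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper, not MXNet code. `adj[u]` lists the nodes that consume
// the output of node u; `in_group[u]` marks nodes proposed for one fusion.
// Returns true if some path leaves the group and later re-enters it, which
// would become a cycle once the group is collapsed into a single node.
bool FusionCreatesCycle(const std::vector<std::vector<int>>& adj,
                        const std::vector<bool>& in_group) {
  const std::size_t n = adj.size();
  std::vector<bool> visited(n, false);
  std::vector<int> stack;
  // Seed the search with every outside node fed directly by the group.
  for (std::size_t u = 0; u < n; ++u) {
    if (!in_group[u]) continue;
    for (int v : adj[u]) {
      if (!in_group[v] && !visited[v]) {
        visited[v] = true;
        stack.push_back(v);
      }
    }
  }
  // Walk forward through outside nodes only; reaching the group again
  // means the collapsed node would depend on its own output.
  while (!stack.empty()) {
    const int u = stack.back();
    stack.pop_back();
    for (int v : adj[u]) {
      if (in_group[v]) return true;
      if (!visited[v]) {
        visited[v] = true;
        stack.push_back(v);
      }
    }
  }
  return false;
}
```

For a graph with edges 0 -> 1, 1 -> 2 and 0 -> 2, fusing nodes 0 and 2 while leaving node 1 outside is reported as creating a cycle, whereas fusing nodes 0 and 1 is not.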
Both of those steps were improved:
- if the inputs of node `n` are also incompatible, then the separation set produced by that node is strictly smaller than the ones produced by those inputs, so there is no need to create it
- if the incompatible node producing the separation set was already part of the set of nodes `n` is incompatible with, then every node in that separation set was already included there as well. In the repro script from the issue, this change cut the number of insertion trials in the forward fusion pass from ~1M to ~2000.

The other problem fixed in this PR is that, due to the variable scope of `state_ptr` in the `CachedOp::Forward` method, the state of `CachedOp` was taken twice for any context, resulting in 2 calls into graph optimization methods instead of 1.

In my local environment, the time spent on the fusion pass in the repro script came down from over 200s to 4.5s on a single GPU.
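The scoping problem can be pictured with a small, self-contained C++ sketch. This is not the `CachedOp` code: `State`, `TakeState`, `ForwardScoped`, `ForwardHoisted` and the counter `g_optimize_calls` are hypothetical stand-ins, and the sketch assumes that every call taking the state also triggers the expensive optimization step.

```cpp
#include <memory>

// Hypothetical stand-ins; none of these names come from MXNet.
struct State {
  int value = 0;
};

int g_optimize_calls = 0;  // counts how often the expensive step runs

// Assumption for this sketch: every call builds a state and runs the
// optimization step once.
std::shared_ptr<State> TakeState() {
  ++g_optimize_calls;
  return std::make_shared<State>();
}

// The pointer lives only inside the inner block, so the code after the
// block has to take the state a second time.
void ForwardScoped() {
  {
    auto state_ptr = TakeState();
    state_ptr->value = 1;
  }  // state_ptr goes out of scope here
  auto state_ptr = TakeState();  // second, redundant call
  state_ptr->value = 2;
}

// The pointer is taken once at function scope and reused afterwards.
void ForwardHoisted() {
  auto state_ptr = TakeState();
  state_ptr->value = 1;
  state_ptr->value = 2;
}
```

Under these assumptions, calling `ForwardScoped` increments the counter twice, while `ForwardHoisted` increments it once.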
@leezu @zburning @Caenorst @samskalicky FYI
Changes
- Fix for the `CachedOp::Forward` method, where the `shared_ptr` to `CachedOpState` was taken twice, resulting in launching the optimization step twice.