Supported nested preproc modules which are called multiple times with different args #2333
Summary:
Ran into 3 issues while enabling the pipeline for a model:

1. Nested preproc module, e.g. `model._sparse_arch._preproc_module`. Finding a module would not work, as this used `getattr` on the model, and swapping the module would fail, as this used `setattr` on the model. Solution: replace `getattr` and `setattr` with `_find_preproc_module_recursive` and `_swap_preproc_module_recursive` respectively (sketched below).
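A minimal sketch of what such recursive helpers could look like, assuming the preproc FQN is a dotted attribute path; the helper bodies here are illustrative, not the actual diff:

```python
import torch

def _find_preproc_module_recursive(model: torch.nn.Module, path: str) -> torch.nn.Module:
    # Walk "a.b.c" one attribute at a time; a plain getattr(model, "a.b.c")
    # fails because the dotted path is not a single attribute name.
    module = model
    for name in path.split("."):
        module = getattr(module, name)
    return module

def _swap_preproc_module_recursive(
    model: torch.nn.Module, new_module: torch.nn.Module, path: str
) -> None:
    # Resolve the parent of the target module, then setattr the leaf name on it.
    parent_path, _, leaf = path.rpartition(".")
    parent = _find_preproc_module_recursive(model, parent_path) if parent_path else model
    setattr(parent, leaf, new_module)
```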
2. Same preproc module called multiple times with different args in `forward()`. Current logic wouldn't handle this correctly because a) we would only ever create 1 instance of each pipelined preproc with its captured arg info from tracing (even though this should be different for each invocation) and b) we would cache results based on the preproc module's FQN only. Solution (sketched below):
   - Append each invocation's `List[ArgInfo]` to the `PipelinedPreproc`'s arg info list via `preproc_module.register_args(preproc_args)`.
   - Add `self._call_idx` to `PipelinedPreproc` that gets incremented each time we call forward, and simply index into the arg info list using this index.
   - Build the cache key as `self._fqn + str(self._call_idx)`. Ideally, we would have a different `cache_key` for each FQN + args + kwargs combination, but materializing this into a str / object could be too expensive, as these args are large model input KJTs / tensors.
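A hypothetical, heavily simplified sketch of the per-call bookkeeping described above; aside from `register_args`, `_call_idx`, and the `fqn + str(call_idx)` cache key quoted in the summary, all names and structure here are assumptions:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List

@dataclass
class ArgInfo:
    # Stand-in for the arg metadata captured during tracing.
    input_attrs: List[str] = field(default_factory=list)

class PipelinedPreproc:
    def __init__(self, preproc_module: Any, fqn: str) -> None:
        self._preproc_module = preproc_module
        self._fqn = fqn
        self._args: List[List[ArgInfo]] = []  # one List[ArgInfo] per call site
        self._call_idx = 0
        self._cache: Dict[str, Any] = {}

    def register_args(self, preproc_args: List[ArgInfo]) -> None:
        # Called once per traced invocation: N call sites -> N entries.
        self._args.append(preproc_args)

    def forward(self, *inputs: Any) -> Any:
        # Select the arg info captured for this specific invocation,
        # then advance the per-call index.
        arg_info = self._args[self._call_idx]  # used to materialize args in the real code
        cache_key = self._fqn + str(self._call_idx)
        self._call_idx += 1
        if cache_key not in self._cache:
            # The real code builds the call args from arg_info; simplified here.
            self._cache[cache_key] = self._preproc_module(*inputs)
        return self._cache[cache_key]
```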
3. Preproc module called with constant / attribute args (e.g. `self.model.constant_value`) as we skip args that aren't `torch.fx.Node` values. However, we should be able to pipeline these cases. Solution: add a new field to `ArgInfo` called `objects` of type `List[Optional[object]]`. After fx tracing, you will have fx immutable collections, such as `torch.fx.immutable_dict` for an immutable `Dict`. Creating a copy converts it back to the mutable original value, so we capture this value in `ArgInfo` (see the sketch below). The potential downside is the extra memory overhead, but for this model in particular, this was just a small string value.
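An illustrative sketch of this third fix; the `objects` field name and its type come from the summary, while the surrounding `ArgInfo` shape and the capture flow are assumptions:

```python
from dataclasses import dataclass, field
from typing import List, Optional

from torch.fx.immutable_collections import immutable_dict

@dataclass
class ArgInfo:
    # ... existing fields describing traced inputs would live here ...
    # New field: captured constant values for args that are not
    # torch.fx.Node (e.g. self.model.constant_value).
    objects: List[Optional[object]] = field(default_factory=list)

# fx tracing hands back immutable collection variants; constructing a
# plain copy restores the original mutable type, and that restored value
# is what gets captured into ArgInfo.objects.
traced_value = immutable_dict({"key": "small_string_value"})
restored = dict(traced_value)  # back to a regular, mutable dict
info = ArgInfo(objects=[restored])
```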
Reviewed By: joshuadeng

Differential Revision: D61155773