Supported nested preproc modules which are called multiple times with different args #2333

Open
wants to merge 2 commits into main

Conversation


@sarckk sarckk commented Aug 22, 2024

Summary:
Ran into 3 issues while enabling pipeline for a model:

  1. The current pipeline logic for finding and swapping a preproc module only works if the preproc module exists at the top level of the model. If the preproc lives inside one of the model's child modules, e.g. `model._sparse_arch._preproc_module`, this logic breaks down: finding the module fails because it used `getattr` on the model, and swapping it fails because it used `setattr` on the model. Solution:
    • Replaced `getattr` and `setattr` with `_find_preproc_module_recursive` and `_swap_preproc_module_recursive` respectively (a minimal sketch of both helpers follows this list).
  2. In this model, the same preproc module was called 2 times with 2 different sets of arguments passed to `forward()`. The current logic wouldn't handle this correctly because a) we would only ever create 1 instance of each pipelined preproc with its captured arg info from tracing (even though this should differ per invocation), and b) we would cache results based on the preproc module's FQN only. Solution:
    • If we see another call to an already-pipelined preproc, we still capture its arguments' graph and append the resulting `List[ArgInfo]` to the `PipelinedPreproc`'s arg info list via `preproc_module.register_args(preproc_args)`.
    • Each time we call the pipelined preproc's forward during pipeline execution, we need to fetch the right arg info list. So I added `self._call_idx` to `PipelinedPreproc`, which gets incremented on each forward call, and we simply index into the arg info list with it (see the second sketch after this list).
    • Changed the cache key to `self._fqn + str(self._call_idx)`. Ideally we would have a different `cache_key` for each FQN + args + kwargs combination, but materializing that into a str / object could be too expensive since these args are large model input KJTs / tensors.
  3. The logic doesn't support the case where an arg to a preproc module is a constant (e.g. `self.model.constant_value`), as we skip args that aren't `torch.fx.Node` values. However, we should be able to pipeline these cases. Solution:
    • Add a new field to `ArgInfo` called `objects` of type `List[Optional[object]]`. After fx tracing, constant collection args show up as fx immutable collections, such as `torch.fx.immutable_dict` for an immutable `Dict`; creating a copy converts them back to their original mutable values, and we capture the value in `ArgInfo` (see the third sketch after this list). The potential downside is extra memory overhead, but for this model in particular it was just a small string value.
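
For reference, a minimal sketch of the recursive lookup/swap from issue 1, assuming the preproc's dotted FQN (e.g. `_sparse_arch._preproc_module`) is already known; the actual helpers in this diff may differ in details:

```python
from typing import Optional

import torch


def _find_preproc_module_recursive(
    model: torch.nn.Module, fqn: str
) -> Optional[torch.nn.Module]:
    # Walk the dotted path one attribute at a time instead of a single
    # flat getattr on the top-level model.
    module: Optional[torch.nn.Module] = model
    for name in fqn.split("."):
        module = getattr(module, name, None)
        if module is None:
            return None
    return module


def _swap_preproc_module_recursive(
    model: torch.nn.Module, new_module: torch.nn.Module, fqn: str
) -> None:
    # setattr must target the direct parent of the preproc, not the root model.
    parent_fqn, _, attr_name = fqn.rpartition(".")
    parent = _find_preproc_module_recursive(model, parent_fqn) if parent_fqn else model
    if parent is None:
        raise AttributeError(f"no module found at '{parent_fqn}'")
    setattr(parent, attr_name, new_module)
```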
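
A minimal sketch of the per-call-site arg info and call-index cache key from issue 2; `ArgInfo`, `_build_args`, and the local result cache here are simplified stand-ins rather than the exact types in this diff:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import torch


@dataclass
class ArgInfo:
    # Simplified stand-in: a chain of attribute names to pull off the model
    # input, plus any constants captured during tracing (issue 3).
    input_attrs: List[str] = field(default_factory=list)
    objects: List[Optional[object]] = field(default_factory=list)


def _build_args(batch: Any, arg_info_list: List[ArgInfo]) -> List[Any]:
    args = []
    for info in arg_info_list:
        if info.objects:
            args.append(info.objects[0])  # captured constant arg
        else:
            value = batch
            for attr in info.input_attrs:
                value = getattr(value, attr)
            args.append(value)
    return args


class PipelinedPreproc(torch.nn.Module):
    # Simplified: one List[ArgInfo] per traced call site of the same module.
    def __init__(self, preproc_module: torch.nn.Module, fqn: str) -> None:
        super().__init__()
        self._preproc_module = preproc_module
        self._fqn = fqn
        self._args: List[List[ArgInfo]] = []
        self._call_idx = 0
        self._cache: Dict[str, Any] = {}

    def register_args(self, preproc_args: List[ArgInfo]) -> None:
        # Tracing a second invocation with different args appends a second list.
        self._args.append(preproc_args)

    def forward(self, batch: Any) -> Any:
        # The FQN alone would collide across the two call sites, so the
        # running call index is folded into the cache key.
        cache_key = self._fqn + str(self._call_idx)
        arg_info_list = self._args[self._call_idx]
        self._call_idx = (self._call_idx + 1) % len(self._args)
        if cache_key not in self._cache:  # a real pipeline would reset this per batch
            self._cache[cache_key] = self._preproc_module(*_build_args(batch, arg_info_list))
        return self._cache[cache_key]
```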
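
And a minimal sketch of the constant-arg conversion from issue 3, showing how an fx immutable collection can be turned back into a plain mutable value before being stored in `ArgInfo.objects` (`_to_mutable` is an illustrative name, not from this diff):

```python
from torch.fx.immutable_collections import immutable_dict, immutable_list


def _to_mutable(value: object) -> object:
    # fx tracing wraps constant dict/list args in immutable containers;
    # rebuilding them as plain built-ins restores ordinary mutable values.
    if isinstance(value, immutable_dict):
        return dict(value)
    if isinstance(value, immutable_list):
        return list(value)
    return value
```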

Reviewed By: joshuadeng

Differential Revision: D61155773

yhshin and others added 2 commits August 21, 2024 11:11
Summary: Add missing pipeline_preproc and custom_model_fwd args.

Differential Revision: D61564467
Supported nested preproc modules which are called multiple times with different args

@facebook-github-bot added the CLA Signed label Aug 22, 2024
@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D61155773
