[After Rebase] Top of Traceable FSDP2 stack #128996
base: gh/yf225/46/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128996
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
❌ 30 New Failures, 3 Unrelated Failures as of commit cf6e033 with merge base 734891a:
NEW FAILURES: the following jobs have failed.
FLAKY: the following jobs failed but were likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: 7cd520d3e76cca56d9e7ee60ec0b12cabf7c2cee Pull Request resolved: #128996
torch/_dynamo/output_graph.py
Outdated
    name,
    **options,
)
return self.side_effects.track_object_existing(target, vt)
shouldn't need this if tracing into inbuilt nn module
torch/_dynamo/utils.py
Outdated
else:
    proxy = mod.__class__.__new__(mod.__class__)
    proxy.__dict__ = mod.__dict__
    return proxy
shouldn't need this anymore
but maybe we need some changes to the existing code (like the places that are still using nn_module_proxy())?
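(For context on the hunk above: nn_module_proxy() builds a shallow proxy that shares the module's attribute dict. A minimal standalone sketch of that pattern; the helper name mirrors the diff, everything else is illustrative.)

```python
import torch.nn as nn

def nn_module_proxy_sketch(mod: nn.Module) -> nn.Module:
    # Allocate an instance of the same class without running __init__,
    # then share (not copy) the original module's attribute dict. Any
    # attribute read or write on the proxy is visible on the original.
    proxy = mod.__class__.__new__(mod.__class__)
    proxy.__dict__ = mod.__dict__
    return proxy

linear = nn.Linear(2, 2)
p = nn_module_proxy_sketch(linear)
assert p.weight is linear.weight  # state is shared, not duplicated
```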
torch/_dynamo/variables/builder.py
Outdated
    value,
    name,
    source=self.get_source(),
)
we shouldn't rely on tx.output.nn_modules
torch/_dynamo/variables/nn_module.py
Outdated
@@ -222,6 +222,9 @@ def _custom_getattr_fallback(self, base, tx, name, options):
    if not isinstance(getattr_fn, types.FunctionType):
        unimplemented("torch.nn.Module with a non-function custom __getattr__")

    if getattr(base, "_is_fsdp_managed_module", False):
        from .builder import VariableBuilder

        return VariableBuilder(tx, options["source"])(getattr_fn(base, name))
we probably don't need this
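(For reference, the fallback above targets modules whose class defines a custom __getattr__. A sketch of the kind of module that would take this path; the class and its attributes are illustrative, while _is_fsdp_managed_module is the flag checked in the hunk.)

```python
import torch.nn as nn

class ModuleWithCustomGetattr(nn.Module):
    """Toy module whose custom __getattr__ serves extra attributes."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self._extras = {"scale": 2.0}

    def __getattr__(self, name):
        # nn.Module.__getattr__ handles parameters/buffers/submodules;
        # anything else falls through to this custom lookup.
        try:
            return super().__getattr__(name)
        except AttributeError:
            extras = self.__dict__.get("_extras", {})
            if name in extras:
                return extras[name]
            raise

mod = ModuleWithCustomGetattr()
mod._is_fsdp_managed_module = True  # flag tested by the hunk above
print(mod.scale)  # resolved through the custom __getattr__ -> 2.0
```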
torch/_dynamo/variables/nn_module.py
Outdated
@@ -317,6 +320,9 @@ def var_getattr(self, tx, name):
    elif is_safe_constant(subobj) or istensor(subobj):
        # Support possibly common cases of class members
        return VariableBuilder(tx, NNModuleSource(source))(subobj)
    elif istype(subobj, types.GetSetDescriptorType):
        assert source
        return VariableBuilder(tx, source)(subobj.__get__(base))
copy this to UnspecializedNNModuleVariable
it would be good to know for which function this happened; then we can send a separate PR for that
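(Background, independent of the PR: types.GetSetDescriptorType is the type of attributes that C-level types expose via PyGetSetDef, and the new branch binds such a descriptor to the instance via __get__. A minimal illustration using a standard CPython example:)

```python
import types

def f():
    pass

# __code__ on a Python function is implemented in C via PyGetSetDef,
# so the descriptor stored on the type is a GetSetDescriptorType.
descriptor = type(f).__dict__["__code__"]
assert isinstance(descriptor, types.GetSetDescriptorType)

# Binding the descriptor to an instance is what subobj.__get__(base)
# does in the hunk above: it produces the attribute's value.
assert descriptor.__get__(f) is f.__code__
```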
torch/_dynamo/variables/nn_module.py
Outdated
            assert_const=False,
        )
    )
)
don't need these
torch/_dynamo/variables/nn_module.py
Outdated
    _named_embed,
    tx=tx,
    key=key,
    source_cls=FSDPNNModuleSource,
we don't need all these, but need a different way to propagate FSDPNNModuleSource
(this is mostly perf-only and shouldn't affect functionality, so we can do this propagation work after aot_eager unit test is landed)
torch/_dynamo/variables/nn_module.py
Outdated
        return super().var_getattr(tx, name)

    def as_python_constant(self):
        return self.value
it's not really a constant lol
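(A hypothetical sketch of the convention being violated, not the actual Dynamo source: as_python_constant() should only succeed for values that cannot change during tracing, and an nn.Module carries mutable parameters and buffers.)

```python
class VariableTrackerSketch:
    """Hypothetical stand-in for Dynamo's VariableTracker base class."""

    def as_python_constant(self):
        # Default: no compile-time constant is available.
        raise NotImplementedError(f"{self} is not a constant")

class ConstantVariableSketch(VariableTrackerSketch):
    def __init__(self, value):
        self.value = value  # e.g. an int, str, or None

    def as_python_constant(self):
        return self.value  # genuinely immutable at trace time

class NNModuleVariableSketch(VariableTrackerSketch):
    def __init__(self, module):
        self.value = module
    # Deliberately no as_python_constant override: a module is mutable,
    # so returning it as a "constant" (as the hunk above does) would let
    # constant folding observe stale state.
```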
torch/_dynamo/variables/nn_module.py
Outdated
    return variables.LambdaVariable(
        lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
    )
return super().var_getattr(tx, name)
can remove
torch/_dynamo/variables/nn_module.py
Outdated
def call_function(
    self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
    return super().call_function(tx, args, kwargs)
can remove
torch/_dynamo/variables/nn_module.py
Outdated
            source=NNModuleSource(_gen_source(source, name)),
        ),
    ]
)
we can put them back
torch/_dynamo/eval_frame.py
Outdated
# TODO(yf225): this is a workaround to allow an in-place fully-sharded module
# to still go into this branch (instead of the second branch). If we don't do
# this, `torch.compile(fully_shard(module_from_user_defined_module_class))`
# will ignore all module hooks, which will break FSDP tracing.
# But is this the right way to support it?
Animesh: this is probably fine
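(For context, the pattern the TODO refers to. A minimal usage sketch assuming the FSDP2 fully_shard API from torch.distributed._composable.fsdp and an already-initialized process group; the model class and sizes are illustrative.)

```python
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

# Assumes torch.distributed is already initialized (e.g. via torchrun).
model = MyModel()
fully_shard(model)  # FSDP2: shards parameters in place, registers hooks

# The TODO above concerns exactly this call: compiling the sharded module
# must keep FSDP's module hooks visible to Dynamo.
compiled = torch.compile(model)
out = compiled(torch.randn(4, 8))
```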
ghstack-source-id: 1c7cb0a6be2af209d3475dc16af259cbea9c39cc Pull Request resolved: #128996
ghstack-source-id: c2b26eeafd0c37107459f7b875cccd7b283954b5 Pull Request resolved: #128996
Stack from ghstack (oldest at bottom):
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @chauhang @d4l3k @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire