From 6a02e0c4a89d3fbb6995299dd745369948ee99ce Mon Sep 17 00:00:00 2001 From: "Animesh Jain (Meta Employee)" Date: Wed, 5 Jun 2024 09:49:00 -0700 Subject: [PATCH] Unspec nn module when global backward hooks are present (#127802) Summary: X-link: https://github.com/pytorch/pytorch/pull/127802 Approved by: https://github.com/jansel ghstack dependencies: #127785 Reviewed By: atalman Differential Revision: D58168157 Pulled By: anijain2305 fbshipit-source-id: ab9bf505439b5df39a2c45447427744db16028a6 --- userbenchmark/dynamo/dynamobench/_dynamo/utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py index 2b42c8dec6..04f01c7576 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py @@ -2106,6 +2106,14 @@ def format_bytecode(prefix, name, filename, line_no, code): all_hook_names = forward_hook_names + backward_hook_names + state_dict_hook_names +def nn_module_has_global_hooks(): + # This is limited to backward hooks for now because NNModuleVariable + # supports fwd hooks underneath. + return len(torch.nn.modules.module._global_backward_hooks) or len( + torch.nn.modules.module._global_backward_pre_hooks + ) + + def nn_module_get_all_hooks( mod, check_forward_hooks=False,