Skip to content

Commit

Permalink
Enable some models with inline_inbuilt_nn_modules (#128315)
Browse files Browse the repository at this point in the history
Summary:
For all models, the number of graph breaks/recompiles is reduced.
For drq, it increases, and this is a legitimate (expected) increase.

X-link: pytorch/pytorch#128315
Approved by: https://github.com/jansel

Reviewed By: izaitsevfb

Differential Revision: D58700178

Pulled By: anijain2305

fbshipit-source-id: 69af5ca4682486b9a418fc59c531e555e89da713

Co-authored-by: Laith Sakka <[email protected]>
  • Loading branch information
2 people authored and facebook-github-bot committed Jun 19, 2024
1 parent 612b3c8 commit 8ab8a3e
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 9 deletions.
27 changes: 18 additions & 9 deletions userbenchmark/dynamo/dynamobench/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2214,6 +2214,10 @@ def skip_models_due_to_control_flow(self):
def guard_on_nn_module_models(self):
    """Names of models that must run with guard_nn_modules=True.

    Base implementation opts no models in; benchmark-suite subclasses
    override this to return their own set (see torchbench.py).
    """
    return set()

@property
def inline_inbuilt_nn_modules_models(self):
    """Names of models to run with inline_inbuilt_nn_modules=True.

    Models in this set are benchmarked under
    torch._dynamo.config.patch(inline_inbuilt_nn_modules=True).
    Base implementation opts no models in; suite subclasses override.
    """
    return set()

def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
    """Abstract hook: return the accuracy tolerance and cosine-similarity
    flag for model *name* on *current_device*.

    Must be overridden by suite-specific runner subclasses.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError

Expand Down Expand Up @@ -4218,16 +4222,21 @@ def detect_and_mark_batch(t):
if name in runner.guard_on_nn_module_models:
guard_ctx = torch._dynamo.config.patch(guard_nn_modules=True)

inline_ctx = contextlib.nullcontext()
if name in runner.inline_inbuilt_nn_modules_models:
inline_ctx = torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)

with guard_ctx:
runner.run_one_model(
name,
model,
example_inputs,
optimize_ctx,
experiment,
explain=args.explain,
tag=args.tag,
)
with inline_ctx:
runner.run_one_model(
name,
model,
example_inputs,
optimize_ctx,
experiment,
explain=args.explain,
tag=args.tag,
)
if args.generate_aot_autograd_stats:
stats_file = output_filename.split(".csv")[0] + "_stats.csv"
output_csv(
Expand Down
13 changes: 13 additions & 0 deletions userbenchmark/dynamo/dynamobench/torchbench.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,19 @@ def guard_on_nn_module_models(self):
"vision_maskrcnn",
}

@property
def inline_inbuilt_nn_modules_models(self):
    """Torchbench models to benchmark with inline_inbuilt_nn_modules=True.

    The runner wraps these models' runs in
    torch._dynamo.config.patch(inline_inbuilt_nn_modules=True).
    """
    # NOTE(review): the maskrcnn/detectron2 entries overlap with
    # guard_on_nn_module_models above — keep the two lists in sync
    # when adding or removing those models.
    opted_in = (
        "basic_gnn_edgecnn",
        "drq",
        "hf_Reformer",
        "DALLE2_pytorch",
        "hf_BigBird",
        "detectron2_maskrcnn_r_50_fpn",
        "detectron2_maskrcnn_r_101_fpn",
        "vision_maskrcnn",
    )
    return set(opted_in)

def load_model(
self,
device,
Expand Down

0 comments on commit 8ab8a3e

Please sign in to comment.