Enable typechecking for testing.py (#112129)
Summary:
X-link: pytorch/pytorch#112129
Approved by: https://github.com/Skylion007
ghstack dependencies: #111894, #111992, #112031, #112127, #112128

Reviewed By: izaitsevfb

Differential Revision: D50778824

Pulled By: int3

fbshipit-source-id: 53402a367156939d24a823d62eedb5a31846bba8
int3 authored and facebook-github-bot committed Oct 30, 2023
1 parent 97e7f0d commit 94126be
Showing 1 changed file with 13 additions and 12 deletions.
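For readers unfamiliar with the setup: "enable typechecking" means this file is now expected to pass static typechecking cleanly, which is why the diff below is almost entirely annotations. As a minimal, hedged sketch (assuming mypy is the checker; the repository's actual invocation and configuration are not shown in this commit), the file could be checked one-off through mypy's Python API:

from mypy import api

# Hypothetical one-off check of the file touched by this commit; the real
# project likely runs its checker through its own tooling and config,
# not this call.
stdout, stderr, exit_status = api.run(
    ["userbenchmark/dynamo/dynamobench/_dynamo/testing.py"]
)
print(stdout)                       # mypy's report (errors, or a success summary)
print("exit status:", exit_status)  # 0 means the file typechecks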
userbenchmark/dynamo/dynamobench/_dynamo/testing.py (13 additions, 12 deletions)

@@ -8,9 +8,10 @@
 import sys
 import types
 import unittest
-from typing import Sequence, Union
+from typing import List, Optional, Sequence, Union
 from unittest.mock import patch

+np: Optional[types.ModuleType] = None
 try:
     import numpy as np
 except ModuleNotFoundError:
@@ -62,7 +63,7 @@ def named_buffers_for_optimized_module(mod):
     return mod._orig_mod.named_buffers


-def remove_optimized_module_prefix(name):
+def remove_optimized_module_prefix(name) -> str:
     return re.sub(r"^_orig_mod[.]", "", name)


@@ -140,21 +141,21 @@ def reduce_to_scalar_loss(out):
     raise NotImplementedError("Don't know how to reduce", type(out))


-def debug_dir():
+def debug_dir() -> str:
     path = os.path.join(os.path.dirname(__file__), "../debug")
     if not os.path.exists(path):
         os.mkdir(path)
     return path


-def debug_dump(name, code: types.CodeType, extra=""):
+def debug_dump(name, code: types.CodeType, extra="") -> None:
     with open(os.path.join(debug_dir(), name), "w") as fd:
         fd.write(
             f"{dis.Bytecode(code).info()}\n\n{dis.Bytecode(code).dis()}\n\n{extra}\n"
         )


-def debug_insert_nops(frame, cache_size, hooks, _):
+def debug_insert_nops(frame, cache_size, hooks, _) -> Optional[GuardedCode]:
     """used to debug jump updates"""

     def insert_nops(instructions, code_options):
@@ -187,7 +188,7 @@ def __init__(self):
         self.frame_count = 0
         self.op_count = 0

-    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
+    def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
         self.frame_count += 1
         for node in gm.graph.nodes:
             if "call" in node.op:
@@ -206,7 +207,7 @@ def __init__(self, backend):
         self.backend = backend
         self.graphs = []

-    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
+    def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
         from .backends.registry import lookup_backend

         self.frame_count += 1
@@ -223,21 +224,21 @@ class EagerAndRecordGraphs:
     def __init__(self):
         self.graphs = []

-    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
+    def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
         self.graphs.append(gm)
         return gm


-def strip_comment(code):
+def strip_comment(code) -> str:
     code = str(code)
     return re.sub(r"(?m)^ *#.*\n?", "", code)


-def remove_trailing_space(code):
+def remove_trailing_space(code) -> str:
     return "\n".join([line.rstrip() for line in code.split("\n")])


-def normalize_gm(gm_str):
+def normalize_gm(gm_str) -> str:
     # strip comments as comments have path to files which may differ from
     # system to system.
     return remove_trailing_space(strip_comment(gm_str))
@@ -252,7 +253,7 @@ def standard_test(self, fn, nargs, expected_ops=None, expected_ops_dynamic=None)
     expected = CompileCounter()
     try:
         gm = torch.fx.symbolic_trace(fn)
-        expected(gm)
+        expected(gm)  # type: ignore[call-arg]  # FIXME: https://github.com/pytorch/pytorch/issues/112230
         print("\nfx.symbolic_trace graph:")
         gm.graph.print_tabular()
         expected_ops = expected.op_count
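A note on the one non-mechanical change above: the new module-level line np: Optional[types.ModuleType] = None declares the guarded numpy import as "a module or None", so a typechecker can accept both the installed and missing cases. Below is a small, self-contained sketch of that pattern, illustrative only and not part of the commit (the except body is truncated in the diff, so the fallback shown here is an assumption):

import types
from typing import Optional

# Declare the name as "module or None" up front so both branches of the
# guarded import agree with a single declared type.
np: Optional[types.ModuleType] = None
try:
    import numpy as np
except ModuleNotFoundError:
    pass  # assumed fallback: np simply stays None


def numpy_available() -> bool:
    # Callers narrow on None before touching numpy attributes.
    return np is not None

With the declaration in place, any use of np has to be guarded by a None check, which is exactly the shape a typechecker can verify.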
