Enable typechecking for utils.py (#112971)
Summary:
X-link: pytorch/pytorch#112971
Approved by: https://github.com/lezcano, https://github.com/jansel
ghstack dependencies: #112130, #112970

Reviewed By: PaliC

Differential Revision: D51140165

Pulled By: int3

fbshipit-source-id: a0613683717d9e697ba3b5d311f7d6908c46a928
int3 authored and facebook-github-bot committed Nov 10, 2023
1 parent a4e7ede commit 15ce7cd
Showing 1 changed file with 77 additions and 43 deletions.
userbenchmark/dynamo/dynamobench/_dynamo/utils.py
@@ -28,13 +28,29 @@
 from contextlib import contextmanager
 from functools import lru_cache, wraps
 from pathlib import Path
-from typing import Any, Dict, Optional, Set, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    ClassVar,
+    Counter,
+    DefaultDict,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Type,
+    Union,
+    ValuesView,
+)
 
 
 try:
     import numpy as np
 except ModuleNotFoundError:
-    np = None
+    np = None  # type: ignore[assignment]
 
 try:
     import torch._logging
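
A minimal standalone sketch of the optional-dependency pattern in this hunk — the imported name is typed as a module, so the None fallback needs a targeted ignore (the dependency below is hypothetical, chosen only for illustration):

    from typing import Optional

    try:
        import tomllib  # stdlib in Python 3.11+; stands in for any optional import
    except ModuleNotFoundError:
        tomllib = None  # type: ignore[assignment]

    def loads_or_none(text: str) -> Optional[dict]:
        # Guard on the module object before use, mirroring the `if np:` check.
        if tomllib is None:
            return None
        return tomllib.loads(text)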
@@ -45,7 +61,12 @@
 
 # NOTE: Make sure `NP_SUPPORTED_MODULES` and `NP_TO_TNP_MODULE` are in sync.
 if np:
-    NP_SUPPORTED_MODULES = (np, np.fft, np.linalg, np.random)
+    NP_SUPPORTED_MODULES: Tuple[types.ModuleType, ...] = (
+        np,
+        np.fft,
+        np.linalg,
+        np.random,
+    )
 
     NP_TO_TNP_MODULE = {
         np: tnp,
@@ -54,7 +75,7 @@
         np.random: tnp.random,
     }
 else:
-    NP_SUPPORTED_MODULES = {}
+    NP_SUPPORTED_MODULES = tuple()
 
     NP_TO_TNP_MODULE = {}
 from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
@@ -73,17 +94,17 @@
 from torch.utils._pytree import tree_map_only
 
 
-counters = collections.defaultdict(collections.Counter)
+counters: DefaultDict[str, Counter[str]] = collections.defaultdict(collections.Counter)
 troubleshooting_url = "https://pytorch.org/docs/master/compile/troubleshooting.html"
 nnmodule_doc_url = "https://pytorch.org/docs/master/compile/nn-module.html"
 nnmodule_doc_url_msg = f"See {nnmodule_doc_url} for more information and limitations."
 log = logging.getLogger(__name__)
 
 # profiling compilation time by function
-compilation_time_metrics = {}
+compilation_time_metrics: Dict[str, List[float]] = {}
 
 # profiling compilation time by frame phase
-frame_phase_timing = {}
+frame_phase_timing: Dict[str, Dict[str, float]] = {}
 
 timer_counter = itertools.count()
 
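The annotations above give module-level mutable containers explicit types so mypy can check every later read and write. The same idiom in miniature, with hypothetical names:

    import collections
    from typing import Counter, DefaultDict, Dict, List

    hits: DefaultDict[str, Counter[str]] = collections.defaultdict(collections.Counter)
    timings: Dict[str, List[float]] = {}

    hits["graph_break"]["call_function"] += 1  # nested Counter created on demand
    timings.setdefault("backend_compile", []).append(0.25)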
@@ -172,7 +193,7 @@ def increment_op_count(cnt):
 # entire_frame_compile:8.574629999999999
 # backend_compile:5.26806
 def print_time_report():
-    total = 0
+    total = 0.0
     total_by_key = {}
     for timings in frame_phase_timing.values():
         for key, timing in timings.items():
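
The `0` to `0.0` change works around mypy fixing a variable's type at its first assignment: accumulating floats into an int-typed name is rejected. A sketch of the same situation with a hypothetical helper:

    from typing import Dict

    def total_seconds(timings: Dict[str, float]) -> float:
        total = 0.0  # with `total = 0`, mypy infers int and flags the `+=` below
        for t in timings.values():
            total += t
        return total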
@@ -378,7 +399,7 @@ def write_record_to_file(filename, exec_record):
         with open(filename, "wb") as f:
             exec_record.dump(f)
     except Exception:
-        log.error("Unable to write execution record %s", filename, exc_info=1)
+        log.error("Unable to write execution record %s", filename, exc_info=True)
 
 
 def count_calls(g: fx.Graph):
@@ -454,7 +475,7 @@ def is_typing(value):
     #
     # NB: we intentionally ignore classes that inherit from Generic, since they
     # can be used as both TypingVariable as well as UserDefinedClassVariable.
-    return isinstance(value, typing._Final) or value is typing.Generic
+    return isinstance(value, typing._Final) or value is typing.Generic  # type: ignore[attr-defined]
 
 
 def is_numpy_int_type(value):
@@ -524,7 +545,7 @@ def make_cell(val=None):
     def f():
         return x
 
-    assert len(f.__closure__) == 1
+    assert f.__closure__ is not None and len(f.__closure__) == 1
     return f.__closure__[0]
 
 
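The strengthened assert is the usual Optional-narrowing idiom: `__closure__` is Optional in the stubs, so the checker needs an explicit None test before it will allow `len()` or indexing. The pattern in isolation, with a hypothetical function:

    from typing import Optional

    def first_line(text: Optional[str]) -> str:
        assert text is not None  # narrows Optional[str] to str for the checker
        return text.splitlines()[0]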
@@ -581,6 +602,7 @@ def create(scope, name, val):
 
 class CleanupManager(ExactWeakKeyDictionary):
     count = 0
+    instance: ClassVar["CleanupManager"]
 
     def _remove_id(self, idx):
         for hook in self.values[idx]:
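
`ClassVar` declares an attribute that belongs to the class object and is assigned outside the class body, which is exactly how `CleanupManager.instance` is used. A hypothetical sketch of the same shape:

    from typing import ClassVar

    class Registry:
        instance: ClassVar["Registry"]  # declared here, assigned after the class body

    Registry.instance = Registry()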
@@ -651,6 +673,7 @@ def torch_clone(x):
 
 
 def clone_inputs(example_inputs):
+    res: Union[Dict[Any, Any], List[Any]]
     if type(example_inputs) is dict:
         res = dict(example_inputs)
         for key, value in res.items():
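
Pre-declaring `res` with a Union lets the two branches assign different container types to a single name without a conflicting-types error. A sketch with a hypothetical helper:

    from typing import Any, Dict, List, Union

    def shallow_copy(x):
        res: Union[Dict[Any, Any], List[Any]]
        if isinstance(x, dict):
            res = dict(x)
        else:
            res = list(x)
        return res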
@@ -757,7 +780,7 @@ class Marker:
     # frustrating ones e.g. torch.return_types.max
     assert cls.__module__ == "torch.return_types"
     obj = cls(map(Marker, range(cls.n_fields)))
-    fields = [None] * cls.n_fields
+    fields: List[Optional[str]] = [None] * cls.n_fields
     for name in dir(obj):
         if name[0] != "_" and isinstance(getattr(obj, name), Marker):
             fields[getattr(obj, name).index] = name
@@ -881,10 +904,10 @@ def check_numpy_ndarray_args(args, kwargs):
     )
 
 
-dict_values = type(dict().values())
-odict_values = type(collections.OrderedDict().values())
-tuple_iterator = type(iter(tuple()))
-tuple_iterator_len = tuple_iterator.__length_hint__
+dict_values: Type[ValuesView[Any]] = type(dict().values())
+odict_values: Type[ValuesView[Any]] = type(collections.OrderedDict().values())
+tuple_iterator: Type[Iterator[Any]] = type(iter(tuple()))
+tuple_iterator_len = tuple_iterator.__length_hint__  # type: ignore[attr-defined]
 object_new = object.__new__
 
 
@@ -923,23 +946,28 @@ def _get_fake_tensor(vt):
 
 
 def iter_contains(items, search, tx, check_tensor_identity=False):
-    from .variables import BuiltinVariable, ConstantVariable, TensorVariable
+    from .variables import (
+        BuiltinVariable,
+        ConstantVariable,
+        TensorVariable,
+        VariableTracker,
+    )
 
     if search.is_python_constant():
-        found = any(
+        found_const = any(
             x.is_python_constant()
             and x.as_python_constant() == search.as_python_constant()
             for x in items
         )
-        return ConstantVariable.create(found)
+        return ConstantVariable.create(found_const)
 
     must_check_tensor_id = False
     if check_tensor_identity and isinstance(search, TensorVariable):
         must_check_tensor_id = True
         # Match of Tensor means match of FakeTensor
         search = _get_fake_tensor(search)
 
-    found = None
+    found: Optional[VariableTracker] = None
     for x in items:
         if must_check_tensor_id:
             if isinstance(x, TensorVariable):
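
The `found` to `found_const` rename exists because mypy assigns one type per name per function: reusing `found` for both the bool and the later Optional tracker would be a conflicting redefinition. A hypothetical miniature of the same fix:

    from typing import List, Optional

    def find(items: List[object], target: object) -> Optional[object]:
        has_match = any(x == target for x in items)  # bool gets its own name
        found: Optional[object] = None
        for x in items:
            if x == target:
                found = x
        return found if has_match else None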
@@ -1255,10 +1283,10 @@ def disable_cache_limit():
 orig_code_map = ExactWeakKeyDictionary()
 
 # keep a record of code_obj -> list of guard failure reasons for logging
-guard_failures = collections.defaultdict(list)
+guard_failures: DefaultDict[Any, List[Any]] = collections.defaultdict(list)
 
 # Keep a record of graph break reasons for logging
-graph_break_reasons = list()
+graph_break_reasons: List["torch._dynamo.output_graph.GraphCompileReason"] = list()
 
 # keep record of compiled code, if we are in "error if recompile"
 # to track code that dynamo has compiled previously
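
The quoted annotation on `graph_break_reasons` is a forward reference: the checker resolves the string without the module being imported at annotation time. A common variant of the idiom, sketched with a hypothetical module (this hunk itself relies on the fully qualified dotted string instead of an extra import):

    from typing import List, TYPE_CHECKING

    if TYPE_CHECKING:
        from myapp.reports import Reason  # hypothetical typing-only import

    break_reasons: List["Reason"] = []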
@@ -1385,6 +1413,8 @@ def extract_fake_example_value(node, required=True):
     if "example_value" in node.meta and is_fake(node.meta["example_value"]):
         return node.meta["example_value"]
     elif required:
+        from torch._dynamo.exc import unimplemented
+
         unimplemented("`FakeTensor` example value was required but not available")
     else:
         return None
@@ -1456,7 +1486,7 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
     except Unsupported:
         raise
     except RuntimeError as e:
-        cause = e
+        cause: BaseException = e
         if e.__cause__ is not None:
             cause = e.__cause__
 
@@ -1636,7 +1666,7 @@ def import_submodule(mod: types.ModuleType):
     """
     Ensure all the files in a given submodule are imported
     """
-    for filename in sorted(os.listdir(os.path.dirname(mod.__file__))):
+    for filename in sorted(os.listdir(os.path.dirname(cast(str, mod.__file__)))):
         if filename.endswith(".py") and filename[0] != "_":
             importlib.import_module(f"{mod.__name__}.{filename[:-3]}")
 
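`cast(str, mod.__file__)` is purely a static assertion: `ModuleType.__file__` is Optional[str] in the stubs, and `cast` narrows it with no runtime check or cost. The call in isolation:

    import os
    import types
    from typing import cast

    def module_dir(mod: types.ModuleType) -> str:
        # cast() asserts non-None to the checker only; it does not validate at runtime
        return os.path.dirname(cast(str, mod.__file__))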
@@ -1681,8 +1711,10 @@ def tensor_static_reason_to_message(reason: TensorStaticReason):
 
 
 def tensor_always_has_static_shape(
-    tensor: Union[torch.Tensor, Any], is_tensor: bool, guard_source: "GuardSource"
-) -> Tuple[bool, TensorStaticReason]:
+    tensor: Union[torch.Tensor, Any],
+    is_tensor: bool,
+    guard_source: "torch._guards.GuardSource",
+) -> Tuple[bool, Optional[TensorStaticReason]]:
     """
     Given a tensor, source, and is_tensor flag, determine if a shape should be static.
@@ -1825,6 +1857,7 @@ def to_numpy_helper(value):
 
 def numpy_to_tensor(value):
     """Convert tnp.ndarray to tensor, leave other types intact. If a list/tuple, loop through it to convert."""
+    assert np is not None
     if isinstance(value, np.ndarray):
         return torch.as_tensor(value)
     if isinstance(value, tnp.ndarray):
@@ -1879,7 +1912,7 @@ def __call__(self, *args, **kwargs):
 class numpy_operator_wrapper:
     """Implements dunder methods for tnp.ndarray via functions from the operator library"""
 
-    def __init__(self, op: str):
+    def __init__(self, op: Callable[..., Any]):
         self.op = op
         self.__name__ = f"wrapped_{op.__name__}"
 
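`Callable[..., Any]` replaces an annotation that was simply wrong (`op` is a function from the operator module, not a str); mypy's fallback type for callables also exposes `__name__`, so the wrapper body typechecks. A runnable sketch under those assumptions:

    import operator
    from typing import Any, Callable

    class wrapped_op:
        def __init__(self, op: Callable[..., Any]):
            self.op = op
            self.__name__ = f"wrapped_{op.__name__}"

        def __call__(self, *args: Any, **kwargs: Any) -> Any:
            return self.op(*args, **kwargs)

    add = wrapped_op(operator.add)
    print(add(2, 3))  # prints 5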
@@ -2052,7 +2085,7 @@ def nextline(lineno, col):
         # left^^^^^ right^^^^^
         # -2 since end_lineno is 1-indexed and because we added an extra
         # bracket to `segment` when calling ast.parse
-        cur_lineno = expr.left.end_lineno - 2
+        cur_lineno = cast(int, expr.left.end_lineno) - 2
         cur_col = normalize(cur_lineno, expr.left.end_col_offset)
         cur_lineno, cur_col = next_valid_char(cur_lineno, cur_col)
 
@@ -2084,27 +2117,27 @@ def nextline(lineno, col):
         # value^^^^^ slice^^^^^
         # subscript^^^^^^^^^^^^^^^^^^^^
         # find left bracket (first '[' after value)
-        left_lineno = expr.value.end_lineno - 2
+        left_lineno = cast(int, expr.value.end_lineno) - 2
         left_col = normalize(left_lineno, expr.value.end_col_offset)
         left_lineno, left_col = next_valid_char(left_lineno, left_col)
         while lines[left_lineno][left_col] != "[":
             left_lineno, left_col = increment(left_lineno, left_col)
         # find right bracket (final character of expression)
-        right_lineno = expr.end_lineno - 2
+        right_lineno = cast(int, expr.end_lineno) - 2
         right_col = normalize(right_lineno, expr.end_col_offset)
         return _Anchors(left_lineno, left_col, right_lineno, right_col)
     elif isinstance(expr, ast.Call):
         # ( func_expr ) (args, kwargs)
         #   func^^^^^
         # call^^^^^^^^^^^^^^^^^^^^^^^^
         # find left bracket (first '(' after func)
-        left_lineno = expr.func.end_lineno - 2
+        left_lineno = cast(int, expr.func.end_lineno) - 2
         left_col = normalize(left_lineno, expr.func.end_col_offset)
         left_lineno, left_col = next_valid_char(left_lineno, left_col)
         while lines[left_lineno][left_col] != "(":
             left_lineno, left_col = increment(left_lineno, left_col)
         # find right bracket (final character of expression)
-        right_lineno = expr.end_lineno - 2
+        right_lineno = cast(int, expr.end_lineno) - 2
         right_col = normalize(right_lineno, expr.end_col_offset)
         return _Anchors(left_lineno, left_col, right_lineno, right_col)
 
@@ -2126,6 +2159,7 @@ def get_instruction_source_311(code: types.CodeType, inst: dis.Instruction) -> str:
     Python's `traceback` module doesn't handle multi-line expressions
     (and their anchor extraction code is not completely correct).
     """
+    assert inst.positions is not None
     if inst.positions.lineno is None:
         return ""
     # The rstrip + "\n" pattern is used throughout this function to handle
@@ -2180,7 +2214,7 @@ def get_instruction_source_311(code: types.CodeType, inst: dis.Instruction) -> str:
         markers = [marker.replace("~", "^") for marker in markers]
     else:
         # make markers mutable
-        markers = [list(marker) for marker in markers]
+        mutable_markers: List[List[str]] = [list(marker) for marker in markers]
 
         # anchor positions do not take start_offset into account
         if anchors.left_end_lineno == 0:
@@ -2189,24 +2223,24 @@ def get_instruction_source_311(code: types.CodeType, inst: dis.Instruction) -> str:
             anchors.right_start_offset += start_offset
 
         # Turn `~`` markers between anchors to `^`
-        for line in range(len(markers)):
-            for col in range(len(markers[line])):
-                if line < anchors.left_end_lineno:
+        for lineno in range(len(markers)):
+            for col in range(len(mutable_markers[lineno])):
+                if lineno < anchors.left_end_lineno:
                     continue
-                if line == anchors.left_end_lineno and col < anchors.left_end_offset:
+                if lineno == anchors.left_end_lineno and col < anchors.left_end_offset:
                     continue
                 if (
-                    line == anchors.right_start_lineno
+                    lineno == anchors.right_start_lineno
                     and col >= anchors.right_start_offset
                 ):
                     continue
-                if line > anchors.right_start_lineno:
+                if lineno > anchors.right_start_lineno:
                     continue
-                if markers[line][col] == "~":
-                    markers[line][col] = "^"
+                if mutable_markers[lineno][col] == "~":
+                    mutable_markers[lineno][col] = "^"
 
         # make markers into strings again
-        markers = ["".join(marker) for marker in markers]
+        markers = ["".join(marker) for marker in mutable_markers]
 
     result = ""
     for i in range(len(markers)):
@@ -2247,7 +2281,7 @@ def is_tensor_base_attr_getter(value):
     return (
         isinstance(value, types.MethodWrapperType)
         and value.__name__ == "__get__"
-        and value.__self__.__objclass__ is torch._C._TensorBase
+        and value.__self__.__objclass__ is torch._C._TensorBase  # type: ignore[attr-defined]
     )
 
 