[BugFix] Remove monkey patching of uninit params (#684)
vmoens authored Feb 21, 2024
1 parent 8ea4e87 commit 50f4577
Showing 1 changed file with 11 additions and 55 deletions.
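In short: PyTorch's lazy modules hold UninitializedParameter/UninitializedBuffer objects until the first forward pass, and their default __torch_function__ rejects most torch operations, including torch.stack. Before this commit, tensordict worked around that by monkey patching UninitializedTensorMixin.__torch_function__ at import time and registering torch.stack in UNINIT_TENSOR_FUNCTIONS; the commit removes that patch and instead detects uninitialized entries per key inside _stack. The stock behavior the patch used to override can be seen with a small illustrative snippet (not part of the diff):

import torch
from torch.nn.parameter import UninitializedParameter

params = [UninitializedParameter(), UninitializedParameter()]
try:
    torch.stack(params)
except ValueError as err:
    # PyTorch's own UninitializedTensorMixin.__torch_function__ rejects most
    # torch functions until the parameter is materialized by a forward pass,
    # which is why tensordict needs a dedicated path when stacking lazy keys.
    print(err)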
66 changes: 11 additions & 55 deletions tensordict/_torch_func.py
@@ -22,7 +22,7 @@
     lazy_legacy,
     set_lazy_legacy,
 )
-from torch import _C, Tensor
+from torch import Tensor
 from torch.nn.parameter import (
     UninitializedBuffer,
     UninitializedParameter,
Expand All @@ -31,7 +31,6 @@

TD_HANDLED_FUNCTIONS: dict[Callable, Callable] = {}
LAZY_TD_HANDLED_FUNCTIONS: dict[Callable, Callable] = {}
UNINIT_TENSOR_FUNCTIONS: dict[Callable, Callable] = {}
T = TypeVar("T", bound="TensorDictBase")


@@ -57,19 +56,6 @@ def decorator(func: Callable) -> Callable:
     return decorator


-def implements_for_uninit_param(
-    torch_function: Callable,
-) -> Callable[[Callable], Callable]:
-    """Register a torch function override for UninitializedTensorMixin."""
-
-    @functools.wraps(torch_function)
-    def decorator(func: Callable) -> Callable:
-        UNINIT_TENSOR_FUNCTIONS[torch_function] = func
-        return func
-
-    return decorator
-
-
 @implements_for_td(torch.unbind)
 def _unbind(td: T, *args: Any, **kwargs: Any) -> tuple[T, ...]:
     return td.unbind(*args, **kwargs)
@@ -424,26 +410,32 @@ def _stack(
     out = {}
     for key in keys:
         out[key] = []
+        is_lazy = False
         tensor_shape = None
         for _tensordict in list_of_tensordicts:
             tensor = _tensordict._get_str(key, default=NO_DEFAULT)
             if isinstance(tensor, UninitializedTensorMixin):
-                pass
+                is_lazy = True
             elif tensor_shape is None:
                 tensor_shape = tensor.shape
             elif tensor.shape != tensor_shape:
                 with set_lazy_legacy(True):
                     return _stack(list_of_tensordicts, dim=dim)
             out[key].append(tensor)
+        out[key] = (out[key], is_lazy)

-    def stack_fn(key_values):
-        key, values = key_values
+    def stack_fn(key, values, is_lazy):
+        if is_lazy:
+            return _stack_uninit_params(values, dim)
         with _ErrorInteceptor(
             key, "Attempted to stack tensors on different devices at key"
         ):
             return torch.stack(values, dim)

-    out = {key: stack_fn((key, value)) for key, value in out.items()}
+    out = {
+        key: stack_fn(key, values, is_lazy)
+        for key, (values, is_lazy) in out.items()
+    }

     return TensorDict(
         out,
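The per-key bookkeeping added above can be summarized in isolation (a simplified sketch, not the library code: the real _stack retries the whole operation with set_lazy_legacy(True) on a shape mismatch, and routes lazy keys to _stack_uninit_params):

import torch
from torch.nn.parameter import UninitializedParameter, UninitializedTensorMixin

def classify_key(values):
    # Mirrors the loop added to _stack: flag keys holding uninitialized (lazy)
    # values and track the common shape of the initialized ones.
    is_lazy = False
    tensor_shape = None
    for value in values:
        if isinstance(value, UninitializedTensorMixin):
            is_lazy = True
        elif tensor_shape is None:
            tensor_shape = value.shape
        elif value.shape != tensor_shape:
            # The real code falls back to a lazy stack here instead of raising.
            raise RuntimeError("heterogeneous shapes under a single key")
    return is_lazy

print(classify_key([torch.zeros(3), torch.ones(3)]))                       # False
print(classify_key([UninitializedParameter(), UninitializedParameter()]))  # True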
@@ -520,42 +512,6 @@ def where(condition, input, other, *, out=None):
     return input.where(condition, other, out=out)


-# monkey patch
-__prev_torch_function__ = UninitializedTensorMixin.__torch_function__
-
-
-def __torch_function__(
-    cls,
-    func: Callable,
-    types: tuple[type, ...],
-    args: tuple[Any, ...] = (),
-    kwargs: dict[str, Any] | None = None,
-) -> Callable:
-    if kwargs is None:
-        kwargs = {}
-    fnc_uninit = UNINIT_TENSOR_FUNCTIONS.get(func, None)
-    if fnc_uninit is not None:
-        return fnc_uninit(*args, **kwargs)
-    if func in cls._allowed_methods or func.__class__.__name__ == "method-wrapper":
-        if kwargs is None:
-            kwargs = {}
-        with _C.DisableTorchFunctionSubclass():
-            return func(*args, **kwargs)
-    # Ideally we'd like to use this from the original __torch_function__
-    # return super().__torch_function__(func, types, args, kwargs)
-    raise ValueError(
-        f"Attempted to use an uninitialized parameter in {func}. "
-        "This error happens when you are using a `LazyModule` or "
-        f"explicitly manipulating `torch.nn.parameter.{cls.__name__}` "
-        "objects. When using LazyModules Call `forward` with a dummy batch "
-        "to initialize the parameters before calling torch functions"
-    )
-
-
-UninitializedTensorMixin.__torch_function__ = classmethod(__torch_function__)
-
-
-@implements_for_uninit_param(torch.stack)
 def _stack_uninit_params(list_of_params, dim=0, out=None):
     if out is not None:
         raise NotImplementedError
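The removed __torch_function__ override above kept PyTorch's original error message, which already points at the usual remedy: run a dummy forward pass so the lazy module materializes its parameters, after which torch.stack works on them directly. A brief illustration (LazyLinear is just one example of a lazy module, not something this commit touches):

import torch

# Lazy modules hold UninitializedParameter objects until the first forward pass.
modules = [torch.nn.LazyLinear(4), torch.nn.LazyLinear(4)]
dummy = torch.zeros(2, 3)
for module in modules:
    module(dummy)  # materializes weight with shape (4, 3) and bias with shape (4,)

# Once materialized, the weights are ordinary Parameters and stack normally.
stacked = torch.stack([module.weight for module in modules])
print(stacked.shape)  # torch.Size([2, 4, 3])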
