Decompositions for upsample linear backward
isuruf committed Apr 1, 2024
1 parent 23faab8 commit ecef740
Showing 2 changed files with 105 additions and 1 deletion.
3 changes: 3 additions & 0 deletions torch/_decomp/__init__.py
@@ -449,6 +449,9 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
            aten.upsample_linear1d,
            aten.upsample_bilinear2d,
            aten.upsample_trilinear3d,
            aten.upsample_linear1d_backward,
            aten.upsample_bilinear2d_backward,
            aten.upsample_trilinear3d_backward,
            aten.upsample_nearest2d_backward,
            aten.view_as_complex,
            aten.xlogy,
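
With these entries in place, the three backward ops are served by the Python decomposition added below whenever the core ATen decomposition table is used (for example by Inductor). A minimal sketch, not part of this commit, that checks the table now contains them:

import torch
from torch._decomp import core_aten_decompositions

aten = torch.ops.aten
decomps = core_aten_decompositions()
# each backward op should now map to a Python decomposition
for op in (
    aten.upsample_linear1d_backward.default,
    aten.upsample_bilinear2d_backward.default,
    aten.upsample_trilinear3d_backward.default,
):
    print(op, op in decomps)
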
103 changes: 102 additions & 1 deletion torch/_decomp/decompositions.py
@@ -3415,7 +3415,6 @@ def upsample_linear1d(
@register_decomposition(
    [aten.upsample_bilinear2d.default, aten.upsample_bilinear2d.out]
)
@aten.upsample_bilinear2d.default.py_impl(DispatchKey.Autograd)
@out_wrapper()
def upsample_bilinear2d(
    input: Tensor,
@@ -3531,6 +3530,108 @@ def get_values(inp_size, out_size, scales, nsqueeze):
return result


@register_decomposition(
    [
        aten.upsample_linear1d_backward,
        aten.upsample_bilinear2d_backward,
        aten.upsample_trilinear3d_backward,
    ]
)
@pw_cast_for_opmath
def _upsample_linear_backward(
    grad_output: Tensor,
    output_size: List[int],
    input_size: List[int],
    align_corners: bool,
) -> Tensor:
    # get dimensions of original image
    n_batch, n_channels = input_size[:2]
    inp_sizes = input_size[2:]
    n_dims = len(inp_sizes)
    scales = [None] * n_dims

    _, dtype = utils.elementwise_dtypes(
        grad_output,
        type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    )

    def get_values(inp_size, out_size, scales, nsqueeze):
        # First Calculate scaling factor
        scale_factor = _compute_scale(inp_size, out_size, align_corners, scales)
        # We have to create arange with int64 dtype and use .to in order to avoid
        # additional kernels creation in inductor and get a perf slowdown
        i = torch.arange(out_size, device=grad_output.device).to(dtype=dtype)

        x_f32 = _compute_source_index(scale_factor, i, align_corners).clamp(min=0.0)
        x_f32 = x_f32.reshape(x_f32.shape[0], *[1] * (nsqueeze))
        x = x_f32.to(torch.int64)
        xp1 = (x + 1).clamp(max=inp_size - 1)
        return x_f32, x, xp1

    values = [
        get_values(inp_size, out_size, scales, n_dims - 1 - i)
        for i, (inp_size, out_size, scales) in enumerate(
            zip(inp_sizes, output_size, scales)
        )
    ]
    xs_f32, xs, xp1s = list(zip(*values))

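    # Expand the per-dimension linear weights into corner weights: after this loop
    # coeffs holds 2**n_dims tensors (2 for linear, 4 for bilinear, 8 for trilinear).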
    coeffs = [1]
    for i in range(n_dims):
        xscale = (xs_f32[i] - xs[i]).clamp(0.0, 1.0).to(dtype)
        new_coeffs: List[Any] = [None] * (2 * len(coeffs))
        new_coeffs[::2] = [torch.mul(coeff, (1 - xscale)) for coeff in coeffs]
        new_coeffs[1::2] = [torch.mul(coeff, xscale) for coeff in coeffs]
        coeffs = new_coeffs

    output_view_shape = [
        n_batch,
        n_channels,
        functools.reduce(operator.mul, output_size),
    ]
    inp_last_dim_size = functools.reduce(operator.mul, input_size[2:])
    input_view_shape = [
        n_batch,
        n_channels,
        inp_last_dim_size,
    ]

    indices = []
    coeff_values = []
    result = grad_output.new_zeros(input_size)
    strides = result.stride()[2:]

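    # Each corner choice `a` picks the lower (x) or upper (xp1) source index per dim;
    # its weighted grad_output is recorded together with the flattened input offset.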
    for i, (coeff, a) in enumerate(zip(coeffs, product(*[[0, 1]] * n_dims))):
        idx = sum(
            [
                xs[k] * strides[k] if a[k] == 0 else xp1s[k] * strides[k]
                for k in range(n_dims)
            ]
        )
        coeff_values.append((coeff * grad_output).view(output_view_shape))
        indices.append(idx.view(output_view_shape[-1]))

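    # Scatter-add every corner contribution in a single index_put over the flattened
    # spatial dimension; accumulate=True sums values that land on the same input cell.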
    result = aten._unsafe_index_put(
        result.view(input_view_shape),
        [None, None, aten.concat(indices)],
        aten.concat(coeff_values, dim=2),
        accumulate=True,
    ).view(input_size)

    # convert output to correct memory format, if necessary
    memory_format = utils.suggest_memory_format(grad_output)

    # following "heuristic: only use channels_last path when it's faster than the contiguous path"
    if grad_output.device.type == "cuda" and n_channels < 16:
        memory_format = torch.contiguous_format

    assert isinstance(result, torch.Tensor)

    result = result.contiguous(memory_format=memory_format)

    return result


# We should be applying decompositions after all transformations
@register_decomposition(aten.is_same_size.default)
def is_same_size(a: Tensor, b: Tensor) -> bool:
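
A quick numerical sanity sketch, not part of the commit: for the bilinear case, the decomposition above should agree with the eager aten.upsample_bilinear2d_backward kernel. The shapes are illustrative assumptions, and _upsample_linear_backward is assumed to be importable from torch._decomp.decompositions as added above.

import torch
from torch._decomp.decompositions import _upsample_linear_backward

# upsampling 4x5 -> 8x10, so the backward maps an 8x10 grad to a 4x5 input grad
grad_out = torch.randn(2, 3, 8, 10, dtype=torch.float64)
ref = torch.ops.aten.upsample_bilinear2d_backward(
    grad_out, [8, 10], [2, 3, 4, 5], False
)
dec = _upsample_linear_backward(grad_out, [8, 10], [2, 3, 4, 5], False)
print(torch.allclose(ref, dec))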
