Decompositions for upsample linear backward
ghstack-source-id: 7ffa4f4508b05aad9f30fd6c8c1a95f76f404ad5
Pull Request resolved: #123222
isuruf committed May 13, 2024
1 parent 1b25ad2 commit d61cc60
Showing 4 changed files with 128 additions and 9 deletions.
6 changes: 0 additions & 6 deletions test/expect/HasDecompTest.test_has_decomposition.expect
@@ -1325,18 +1325,12 @@ aten::unsqueeze_copy
aten::unsqueeze_copy.out
aten::upsample_bicubic2d_backward
aten::upsample_bicubic2d_backward.grad_input
aten::upsample_bilinear2d_backward
aten::upsample_bilinear2d_backward.grad_input
aten::upsample_linear1d_backward
aten::upsample_linear1d_backward.grad_input
aten::upsample_nearest1d_backward
aten::upsample_nearest1d_backward.grad_input
aten::upsample_nearest2d_backward
aten::upsample_nearest2d_backward.grad_input
aten::upsample_nearest3d_backward
aten::upsample_nearest3d_backward.grad_input
aten::upsample_trilinear3d_backward
aten::upsample_trilinear3d_backward.grad_input
aten::values
aten::values_copy
aten::values_copy.out
3 changes: 3 additions & 0 deletions torch/_decomp/__init__.py
@@ -450,6 +450,9 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
            aten.upsample_linear1d,
            aten.upsample_bilinear2d,
            aten.upsample_trilinear3d,
            aten.upsample_linear1d_backward,
            aten.upsample_bilinear2d_backward,
            aten.upsample_trilinear3d_backward,
            aten.upsample_nearest2d_backward,
            aten.view_as_complex,
            aten.xlogy,
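For context, a minimal sketch (not part of this commit; op choice and shapes are illustrative) of how the backward decompositions registered above can be pulled into a traced graph; get_decompositions accepts the aten overload packets listed in core_aten_decompositions:

import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

# Build a decomposition table for the newly decomposed backward ops.
decomps = get_decompositions(
    [
        torch.ops.aten.upsample_linear1d_backward,
        torch.ops.aten.upsample_bilinear2d_backward,
        torch.ops.aten.upsample_trilinear3d_backward,
    ]
)

def loss(x):
    return torch.nn.functional.interpolate(
        x, scale_factor=2, mode="bilinear", align_corners=False
    ).sum()

# Tracing the gradient with this table should produce the index_put-based
# decomposition instead of a call to aten.upsample_bilinear2d_backward.
x = torch.randn(1, 3, 8, 8)
gm = make_fx(torch.func.grad(loss), decomposition_table=decomps)(x)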
126 changes: 125 additions & 1 deletion torch/_decomp/decompositions.py
@@ -3538,7 +3538,6 @@ def upsample_linear1d(
@register_decomposition(
    [aten.upsample_bilinear2d.default, aten.upsample_bilinear2d.out]
)
@aten.upsample_bilinear2d.default.py_impl(DispatchKey.Autograd)
@out_wrapper()
def upsample_bilinear2d(
    input: Tensor,
@@ -3673,6 +3672,131 @@ def get_values(inp_size, out_size, scales, nsqueeze):
    return result


@register_decomposition(aten.upsample_linear1d_backward)
@out_wrapper("grad_input")
def upsample_linear1d_backward(
    grad_output: Tensor,
    output_size: List[int],
    input_size: List[int],
    align_corners: bool,
    scales_w: Optional[float] = None,
) -> Tensor:
    return _upsample_linear_backward(
        grad_output, output_size, input_size, align_corners, [scales_w]
    )


@register_decomposition(aten.upsample_bilinear2d_backward)
@out_wrapper("grad_input")
def upsample_bilinear2d_backward(
    grad_output: Tensor,
    output_size: List[int],
    input_size: List[int],
    align_corners: bool,
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> Tensor:
    return _upsample_linear_backward(
        grad_output, output_size, input_size, align_corners, [scales_h, scales_w]
    )


@register_decomposition(aten.upsample_trilinear3d_backward)
@out_wrapper("grad_input")
def upsample_trilinear3d_backward(
    grad_output: Tensor,
    output_size: List[int],
    input_size: List[int],
    align_corners: bool,
    scales_d: Optional[float] = None,
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> Tensor:
    return _upsample_linear_backward(
        grad_output,
        output_size,
        input_size,
        align_corners,
        [scales_d, scales_h, scales_w],
    )


@pw_cast_for_opmath
def _upsample_linear_backward(
    grad_output: Tensor,
    output_size: List[int],
    input_size: List[int],
    align_corners: bool,
    scales: List[Optional[float]],
) -> Tensor:
    # get dimensions of original image
    n_batch, n_channels = input_size[:2]
    inp_sizes = input_size[2:]
    n_dims = len(inp_sizes)

    _, dtype = utils.elementwise_dtypes(
        grad_output,
        type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    )

    def get_values(inp_size, out_size, scales, nsqueeze):
        # First, calculate the scaling factor
        scale_factor = _compute_scale(inp_size, out_size, align_corners, scales)
        # We have to create the arange with int64 dtype and use .to in order
        # to avoid creating additional kernels in inductor, which would cause
        # a perf slowdown
        i = torch.arange(out_size, device=grad_output.device).to(dtype=dtype)

        x_f32 = _compute_source_index(scale_factor, i, align_corners).clamp(min=0.0)
        x_f32 = x_f32.reshape(x_f32.shape[0], *[1] * (nsqueeze))
        x = x_f32.to(torch.int64)
        xp1 = (x + 1).clamp(max=inp_size - 1)
        return x_f32, x, xp1

    values = [
        get_values(inp_size, out_size, scales, n_dims - 1 - i)
        for i, (inp_size, out_size, scales) in enumerate(
            zip(inp_sizes, output_size, scales)
        )
    ]
    xs_f32, xs, xp1s = list(zip(*values))

    # Using functions here makes it easier for inductor to do fusions.
    # This results in a * (b * (c * x)) instead of (a * (b * c)) * x
    # in the generated code, so inductor does not try to
    # create temporaries for (b * c).
    coeff_fns = [lambda x: x]
    for i in range(n_dims):
        xscale = (xs_f32[i] - xs[i]).clamp(0.0, 1.0).to(dtype)
        new_coeff_fns: List[Any] = [None] * (2 * len(coeff_fns))
        # Bind coeff_fn and xscale as default arguments so each lambda
        # captures the current values rather than late-binding to the
        # values from the final loop iteration.
        new_coeff_fns[::2] = [
            lambda x, fn=coeff_fn, s=xscale: torch.mul(fn(x), (1 - s))
            for coeff_fn in coeff_fns
        ]
        new_coeff_fns[1::2] = [
            lambda x, fn=coeff_fn, s=xscale: torch.mul(fn(x), s)
            for coeff_fn in coeff_fns
        ]
        coeff_fns = new_coeff_fns

    result = grad_output.new_zeros(input_size)
    for coeff_fn, a in zip(coeff_fns, product(*[[0, 1]] * n_dims)):
        idx = [None, None] + [xs[k] if a[k] == 0 else xp1s[k] for k in range(n_dims)]
        result = aten._unsafe_index_put(
            result, idx, coeff_fn(grad_output), accumulate=True
        )

    # convert the output to the correct memory format, if necessary
    memory_format = utils.suggest_memory_format(grad_output)

    # following the eager kernel's heuristic: only use the channels_last path
    # when it's faster than the contiguous path
    if grad_output.device.type == "cuda" and n_channels < 16:
        memory_format = torch.contiguous_format

    assert isinstance(result, torch.Tensor)

    result = result.contiguous(memory_format=memory_format)

    return result


# We should be applying decompositions after all transformations
@register_decomposition(aten.is_same_size.default)
def is_same_size(a: Tensor, b: Tensor) -> bool:
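The decomposition distributes each grad_output element to the 2**n_dims neighboring input positions, weighting each scatter by the product of (1 - xscale) or xscale per spatial dim. A hypothetical sanity check (not part of the commit) comparing it against the eager kernel for a small 1d case:

import torch
from torch._decomp.decompositions import upsample_linear1d_backward

grad_out = torch.randn(1, 1, 8, dtype=torch.float64)
ref = torch.ops.aten.upsample_linear1d_backward(
    grad_out, output_size=[8], input_size=[1, 1, 4], align_corners=False
)
# The decomposition should accumulate the same gradients as the aten kernel.
res = upsample_linear1d_backward(grad_out, [8], [1, 1, 4], False)
torch.testing.assert_close(res, ref)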
2 changes: 0 additions & 2 deletions torch/_inductor/lowering.py
@@ -2143,9 +2143,7 @@ def is_aligned(x):
make_fallback(aten.fractional_max_pool3d_backward)
make_fallback(aten.replication_pad1d_backward)
make_fallback(aten.replication_pad2d_backward)
make_fallback(aten.upsample_linear1d_backward)
make_fallback(aten.upsample_bicubic2d_backward, require_contiguous)
make_fallback(aten.upsample_trilinear3d_backward)
make_fallback(aten.grid_sampler_2d_backward, require_dense)
make_fallback(aten._pdist_backward)

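With these fallbacks removed, inductor generates code for the backward ops from the decompositions instead of calling back into the aten kernels. A small sketch (shapes and mode are illustrative, not from the commit):

import torch

def upsample(x):
    return torch.nn.functional.interpolate(
        x, scale_factor=2.0, mode="trilinear", align_corners=True
    )

x = torch.randn(2, 4, 4, 4, 4, requires_grad=True)
y = torch.compile(upsample)(x)
# The backward of trilinear upsampling is now compiled from the
# decomposition rather than dispatched to aten.upsample_trilinear3d_backward.
y.sum().backward()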
