
Decompositions for upsample linear backward #123222

Open
isuruf wants to merge 22 commits into gh/isuruf/37/base

Conversation

pytorch-bot bot commented Apr 2, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/123222

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 349827a with merge base c888ee3:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

isuruf added a commit that referenced this pull request Apr 2, 2024
ghstack-source-id: 8f8ea4ddbdf5d7d000d00448bc711265f343d549
Pull Request resolved: #123222

isuruf commented Apr 2, 2024

Need triton-lang/triton#3491 for performance


lezcano commented Apr 3, 2024

We need to remove the fallback for these.
Do we generate code competitive with eager once triton-lang/triton#3491 is merged?

isuruf added a commit that referenced this pull request Apr 3, 2024
ghstack-source-id: 690147096db2c66c8c38b73c6b82e1592a92d7af
Pull Request resolved: #123222
isuruf added a commit that referenced this pull request Apr 4, 2024
ghstack-source-id: 3e37a52f488305de4387e57c48332d6813698c24
Pull Request resolved: #123222

isuruf commented Apr 4, 2024

With the latest commit, here are some benchmarks. (I'm not sure what the appropriate values are for a benchmark here)

[------------------------- upsample_linear_backward -------------------------]
                                                       |  Decomposed  |  Eager
12 threads: ------------------------------------------------------------------
      {'scaling': 2, 'shape': (3, 5, 10000000)}        |     5.55     |   5.33
      {'scaling': 2, 'shape': (3, 5, 10000, 1000)}     |     9.99     |  10.58
      {'scaling': 2, 'shape': (3, 5, 100, 1000, 100)}  |    22.30     |  22.11

@lezcano left a comment


An idea: I think we can do all this just with one index_put_ and broadcasting.
There are a few ideas here:

  • We have 2^n indices (and coefficients) of length d_i each to put together in one tensor.
  • To do that, the idea is to put together n 1-dimensional tensors of length 2*d_i

The naïve way of doing this would be via torch.cat, but that doesn't fuse well in inductor. So, we are going to do it by playing around with the indexing.

To do cat(arange(n), arange(n)), we can do arange(2n) % n. With that we can get x2_f32 = cat(x_f32, x_f32).
To get cat(x, xp1), we can do (x2_f32.to(int64) + (arange(2n) // n)).clamp(max=inp_size - 1).

We do that for each dimension and, with broadcasting, we get all the indices.

With pretty much the same tricks we can get the coefficients and then do one index_put.

Would be great to see if this can generate just one performant kernel!

nb. the same tricks apply to the forward of this op, if I'm not mistaken. Funnily enough, this just occurred to me. I didn't manage to write this in one kernel a year ago when I implemented decomps for a couple of these ops :D
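
A minimal 1-D sketch of the arange trick described above (a standalone illustration with made-up sizes, not code from this PR):

import torch

# 1-D case, align_corners=True, hypothetical sizes.
out_size, inp_size = 10, 5
scale = (inp_size - 1) / (out_size - 1)

i2 = torch.arange(2 * out_size)                        # one arange instead of torch.cat
x2_f32 = (i2 % out_size).to(torch.float32) * scale     # == cat(x_f32, x_f32)
x2 = (x2_f32.to(torch.int64) + i2 // out_size).clamp(max=inp_size - 1)
# The first half of x2 is the "left" index x, the second half is x + 1 (clamped),
# i.e. x2 == cat(x, xp1) without a torch.cat that would block fusion.
# The same i2 // out_size selector can flip between (1 - w) and w for the coefficients.

# Reference built with torch.cat, just to check the identity.
x_f32 = torch.arange(out_size, dtype=torch.float32) * scale
x = x_f32.to(torch.int64)
xp1 = (x + 1).clamp(max=inp_size - 1)
assert torch.equal(x2, torch.cat([x, xp1]))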


isuruf commented Apr 5, 2024

Thanks for the suggestion. I tried that method earlier in the week: isuruf@ecef740. I checked that the generated triton code was a single kernel.
Its performance was really bad (2.5x slowdown compared to eager for 2D).
The issue is that instead of iterating once over the indices of grad_output (doing 2^n atomic adds per element), the triton kernel now iterates over 2^n * size(grad_output) indices, which results in a slowdown.

[------------------------- upsample_linear_backward -------------------------]
                                                       |  Decomposed  |  Eager
12 threads: ------------------------------------------------------------------
      {'scaling': 2, 'shape': (3, 5, 10000000)}        |     6.54     |   5.34
      {'scaling': 2, 'shape': (3, 5, 10000, 1000)}     |    25.76     |  10.57
      {'scaling': 2, 'shape': (3, 5, 100, 1000, 100)}  |    56.83     |  22.09
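
For a sense of scale, a back-of-the-envelope count for the 3D case, using the benchmark shapes above:

grad_output_numel = 3 * 5 * 100 * 1000 * 100   # 150_000_000, what the current kernel iterates over
grad_input_numel = 3 * 5 * 50 * 500 * 50       # 18_750_000, the buffer the atomic adds accumulate into
index_put_numel = 2**3 * grad_output_numel     # 1_200_000_000 index tuples for the single index_put_ variant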


lezcano commented Apr 5, 2024

That's fair enough. Looking at the eager implementation, I see we end up doing it like your current approach, where each thread does 2^n atomic_adds.

Mind posting the triton kernels with the current approach and the benchmarking script?

isuruf added a commit that referenced this pull request Apr 5, 2024
ghstack-source-id: 01df566cd9ede5e5ec05876513d92d4d04417163
Pull Request resolved: #123222

isuruf commented Apr 5, 2024

3D:

from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align

from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()


# kernel path: /tmp/torchinductor_isuruf/sh/csh7njex4jcmtkekmbkknh5jednzydzaderybrvqvfdldqetdp57.py
# Source Nodes: [], Original ATen: []

triton_poi_fused_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[33554432], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_0', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '13d5379970a55f2f2c4bb8dbeb907c03d2af7e5fb1c9d1b1aa5bf5794d5f2277'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 18750000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = 0.0
    tl.store(out_ptr0 + (x0), tmp0, xmask)
''', device_str='cuda')

import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream


# kernel path: /tmp/torchinductor_isuruf/26/c26lclpgh3vwdbju5ij5h7zyr3pgu42s246arjplhyyedtcatj2n.py
# Source Nodes: [], Original ATen: []

triton_poi_fused_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[268435456], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_1', 'mutated_arg_names': ['out_ptr0'], 'no_x_dim': False, 'backend_hash': '13d5379970a55f2f2c4bb8dbeb907c03d2af7e5fb1c9d1b1aa5bf5794d5f2277'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 150000000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = (xindex // 100000) % 100
    x1 = (xindex // 100) % 1000
    x0 = xindex % 100
    x4 = xindex
    x3 = (xindex // 10000000)
    tmp18 = tl.load(in_ptr0 + (x4), xmask)
    tmp0 = x2
    tmp1 = tmp0.to(tl.float32)
    tmp2 = 0.494949494949495
    tmp3 = tmp1 * tmp2
    tmp4 = 0.0
    tmp5 = triton_helpers.maximum(tmp3, tmp4)
    tmp6 = tmp5.to(tl.int32)
    tmp7 = x1
    tmp8 = tmp7.to(tl.float32)
    tmp9 = 0.4994994994994995
    tmp10 = tmp8 * tmp9
    tmp11 = triton_helpers.maximum(tmp10, tmp4)
    tmp12 = tmp11.to(tl.int32)
    tmp13 = x0
    tmp14 = tmp13.to(tl.float32)
    tmp15 = tmp14 * tmp2
    tmp16 = triton_helpers.maximum(tmp15, tmp4)
    tmp17 = tmp16.to(tl.int32)
    tmp19 = tmp17.to(tl.float32)
    tmp20 = tmp16 - tmp19
    tmp21 = triton_helpers.maximum(tmp20, tmp4)
    tmp22 = 1.0
    tmp23 = triton_helpers.minimum(tmp21, tmp22)
    tmp24 = tmp18 * tmp23
    tmp25 = tmp24 * tmp23
    tmp26 = tmp22 - tmp23
    tmp27 = tmp25 * tmp26
    tmp28 = tl.full([1], 1, tl.int64)
    tmp29 = tmp17 + tmp28
    tmp30 = tl.full([1], 49, tl.int64)
    tmp31 = triton_helpers.minimum(tmp29, tmp30)
    tmp32 = tmp25 * tmp23
    tmp33 = tmp12 + tmp28
    tmp34 = tl.full([1], 499, tl.int64)
    tmp35 = triton_helpers.minimum(tmp33, tmp34)
    tmp36 = tmp6 + tmp28
    tmp37 = triton_helpers.minimum(tmp36, tmp30)
    tl.atomic_add(out_ptr0 + (tmp17 + (50*tmp12) + (25000*tmp6) + (1250000*x3)), tmp27, xmask)
    tl.atomic_add(out_ptr0 + (tmp31 + (50*tmp12) + (25000*tmp6) + (1250000*x3)), tmp32, xmask)
    tl.atomic_add(out_ptr0 + (tmp17 + (50*tmp35) + (25000*tmp6) + (1250000*x3)), tmp27, xmask)
    tl.atomic_add(out_ptr0 + (tmp31 + (50*tmp35) + (25000*tmp6) + (1250000*x3)), tmp32, xmask)
    tl.atomic_add(out_ptr0 + (tmp17 + (50*tmp12) + (25000*tmp37) + (1250000*x3)), tmp27, xmask)
    tl.atomic_add(out_ptr0 + (tmp31 + (50*tmp12) + (25000*tmp37) + (1250000*x3)), tmp32, xmask)
    tl.atomic_add(out_ptr0 + (tmp17 + (50*tmp35) + (25000*tmp37) + (1250000*x3)), tmp27, xmask)
    tl.atomic_add(out_ptr0 + (tmp31 + (50*tmp35) + (25000*tmp37) + (1250000*x3)), tmp32, xmask)
''', device_str='cuda')


async_compile.wait(globals())
del async_compile

def call(args):
    args_1, = args
    args.clear()
    assert_size_stride(args_1, (3, 5, 100, 1000, 100), (50000000, 10000000, 100000, 100, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((3, 5, 50, 500, 50), (6250000, 1250000, 25000, 50, 1), torch.float32)
        # Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_0.run(buf0, 18750000, grid=grid(18750000), stream=stream0)
        # Source Nodes: [], Original ATen: []
        triton_poi_fused_1.run(args_1, buf0, 150000000, grid=grid(150000000), stream=stream0)
        del args_1
    return (buf0, )


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    args_1 = rand_strided((3, 5, 100, 1000, 100), (50000000, 10000000, 100000, 100, 1), device='cuda:0', dtype=torch.float32)
    fn = lambda: call([args_1])
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)

2D:

from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align

from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()


# kernel path: /tmp/torchinductor_isuruf/by/cbyr7653xub6dabl6wfbowmcedms4u3auahw3nwzyho4ewtg57bj.py
# Source Nodes: [], Original ATen: []

triton_poi_fused_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[67108864], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_0', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '13d5379970a55f2f2c4bb8dbeb907c03d2af7e5fb1c9d1b1aa5bf5794d5f2277'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 37500000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = 0.0
    tl.store(out_ptr0 + (x0), tmp0, xmask)
''', device_str='cuda')

import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream


# kernel path: /tmp/torchinductor_isuruf/4f/c4f7tg334eiuon7li3rdg4g6ciqrjsa7bq5ye7lwmoiqkop4onwg.py
# Source Nodes: [], Original ATen: []

triton_poi_fused_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[268435456], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_1', 'mutated_arg_names': ['out_ptr0'], 'no_x_dim': False, 'backend_hash': '13d5379970a55f2f2c4bb8dbeb907c03d2af7e5fb1c9d1b1aa5bf5794d5f2277'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 150000000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = (xindex // 1000) % 10000
    x0 = xindex % 1000
    x3 = xindex
    x2 = (xindex // 10000000)
    tmp13 = tl.load(in_ptr0 + (x3), xmask)
    tmp0 = x1
    tmp1 = tmp0.to(tl.float32)
    tmp2 = 0.49994999499949994
    tmp3 = tmp1 * tmp2
    tmp4 = 0.0
    tmp5 = triton_helpers.maximum(tmp3, tmp4)
    tmp6 = tmp5.to(tl.int32)
    tmp7 = x0
    tmp8 = tmp7.to(tl.float32)
    tmp9 = 0.4994994994994995
    tmp10 = tmp8 * tmp9
    tmp11 = triton_helpers.maximum(tmp10, tmp4)
    tmp12 = tmp11.to(tl.int32)
    tmp14 = tmp12.to(tl.float32)
    tmp15 = tmp11 - tmp14
    tmp16 = triton_helpers.maximum(tmp15, tmp4)
    tmp17 = 1.0
    tmp18 = triton_helpers.minimum(tmp16, tmp17)
    tmp19 = tmp13 * tmp18
    tmp20 = tmp17 - tmp18
    tmp21 = tmp19 * tmp20
    tmp22 = tl.full([1], 1, tl.int64)
    tmp23 = tmp12 + tmp22
    tmp24 = tl.full([1], 499, tl.int64)
    tmp25 = triton_helpers.minimum(tmp23, tmp24)
    tmp26 = tmp19 * tmp18
    tmp27 = tmp6 + tmp22
    tmp28 = tl.full([1], 4999, tl.int64)
    tmp29 = triton_helpers.minimum(tmp27, tmp28)
    tl.atomic_add(out_ptr0 + (tmp12 + (500*tmp6) + (2500000*x2)), tmp21, xmask)
    tl.atomic_add(out_ptr0 + (tmp25 + (500*tmp6) + (2500000*x2)), tmp26, xmask)
    tl.atomic_add(out_ptr0 + (tmp12 + (500*tmp29) + (2500000*x2)), tmp21, xmask)
    tl.atomic_add(out_ptr0 + (tmp25 + (500*tmp29) + (2500000*x2)), tmp26, xmask)
''', device_str='cuda')


async_compile.wait(globals())
del async_compile

def call(args):
    args_1, = args
    args.clear()
    assert_size_stride(args_1, (3, 5, 10000, 1000), (50000000, 10000000, 1000, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((3, 5, 5000, 500), (12500000, 2500000, 500, 1), torch.float32)
        # Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_0.run(buf0, 37500000, grid=grid(37500000), stream=stream0)
        # Source Nodes: [], Original ATen: []
        triton_poi_fused_1.run(args_1, buf0, 150000000, grid=grid(150000000), stream=stream0)
        del args_1
    return (buf0, )


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    args_1 = rand_strided((3, 5, 10000, 1000), (50000000, 10000000, 1000, 1), device='cuda:0', dtype=torch.float32)
    fn = lambda: call([args_1])
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)

1D:

from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align

from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()


# kernel path: /tmp/torchinductor_isuruf/62/c62wpwuw5yo643cfkemlzsyfuzwdrxgqz4xoowq72dxdc2xht5ct.py
# Source Nodes: [], Original ATen: []

triton_poi_fused_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[134217728], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_0', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '13d5379970a55f2f2c4bb8dbeb907c03d2af7e5fb1c9d1b1aa5bf5794d5f2277'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 75000000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = 0.0
    tl.store(out_ptr0 + (x0), tmp0, xmask)
''', device_str='cuda')

import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream


# kernel path: /tmp/torchinductor_isuruf/cg/ccgowdbembfggavdbc3yb3w6ymffzvkkluje3w3vdlqlssfgobua.py
# Source Nodes: [], Original ATen: []

triton_poi_fused_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[268435456], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_1', 'mutated_arg_names': ['out_ptr0'], 'no_x_dim': False, 'backend_hash': '13d5379970a55f2f2c4bb8dbeb907c03d2af7e5fb1c9d1b1aa5bf5794d5f2277'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 150000000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 10000000
    x2 = xindex
    x1 = (xindex // 10000000)
    tmp7 = tl.load(in_ptr0 + (x2), xmask)
    tmp0 = x0
    tmp1 = tmp0.to(tl.float32)
    tmp2 = 0.499999949999995
    tmp3 = tmp1 * tmp2
    tmp4 = 0.0
    tmp5 = triton_helpers.maximum(tmp3, tmp4)
    tmp6 = tmp5.to(tl.int32)
    tmp8 = tmp6.to(tl.float32)
    tmp9 = tmp5 - tmp8
    tmp10 = triton_helpers.maximum(tmp9, tmp4)
    tmp11 = 1.0
    tmp12 = triton_helpers.minimum(tmp10, tmp11)
    tmp13 = tmp11 - tmp12
    tmp14 = tmp7 * tmp13
    tmp15 = tl.full([1], 1, tl.int64)
    tmp16 = tmp6 + tmp15
    tmp17 = tl.full([1], 4999999, tl.int64)
    tmp18 = triton_helpers.minimum(tmp16, tmp17)
    tmp19 = tmp7 * tmp12
    tl.atomic_add(out_ptr0 + (tmp6 + (5000000*x1)), tmp14, xmask)
    tl.atomic_add(out_ptr0 + (tmp18 + (5000000*x1)), tmp19, xmask)
''', device_str='cuda')


async_compile.wait(globals())
del async_compile

def call(args):
    args_1, = args
    args.clear()
    assert_size_stride(args_1, (3, 5, 10000000), (50000000, 10000000, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((3, 5, 5000000), (25000000, 5000000, 1), torch.float32)
        # Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_0.run(buf0, 75000000, grid=grid(75000000), stream=stream0)
        # Source Nodes: [], Original ATen: []
        triton_poi_fused_1.run(args_1, buf0, 150000000, grid=grid(150000000), stream=stream0)
        del args_1
    return (buf0, )


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    args_1 = rand_strided((3, 5, 10000000), (50000000, 10000000, 1), device='cuda:0', dtype=torch.float32)
    fn = lambda: call([args_1])
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)

Benchmark script:

import torch
from torch.testing import make_tensor
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils.benchmark import Timer, Compare
from torch._inductor.compile_fx import compile_fx_inner, cudagraphify_impl
from torch._inductor.decomposition import decompositions
from itertools import product
from functools import partial

aten = torch.ops.aten

torch._logging.set_logs(output_code=True)

benchmark_name = "upsample_linear_backward"
Ss = [
    [3, 5, 10000000, 2],
    [3, 5, 10000, 1000, 2],
    [3, 5, 100, 1000, 100, 2],
]


def gen_inputs():
    # Each entry in Ss is a grad_output shape followed by the scaling factor.
    for inp in Ss:
        shape = inp[:-1]
        yield [torch.randn(shape, dtype=torch.float32, device="cuda"), inp[-1:]]


def benchmark(label, f, x, kwargs):
    kwargs = kwargs.copy()
    kwargs["shape"] = tuple(x.shape)
    return Timer(
        "f([x,])",
        globals=locals(),
        label=benchmark_name,
        description=label,
        sub_label=str(kwargs),
        num_threads=torch.get_num_threads(),
    ).blocked_autorange(min_run_time=4)


def compare(args):
    x, props = args
    print(f"{tuple(x.shape)}")
    if x.dim() == 5:
        aten_func = aten.upsample_trilinear3d_backward
    elif x.dim() == 4:
        aten_func = aten.upsample_bilinear2d_backward
    else:
        aten_func = aten.upsample_linear1d_backward

    kwargs = {
        "scaling": props[0],
    }

    # Original input size: batch/channel dims unchanged, spatial dims divided by the scaling factor.
    inp_size = [*x.shape[:2], *[d//props[0] for d in x.shape[2:]]]

    def f(args):
        (x,) = args
        val = aten_func(x, x.shape[2:], inp_size, align_corners=True)
        return (val,)

    args = [x]

    decomposed = make_fx(f, decomposition_table=decompositions, tracing_mode="fake")(args)
    compiled_decomposed = compile_fx_inner(decomposed, args, cudagraphs=False)
    yield benchmark(f"Decomposed", compiled_decomposed, *args, kwargs)

    # non_decomposed = make_fx(f, tracing_mode="fake")(args)
    # compiled_nondecomposed = compile_fx_inner(non_decomposed, args, cudagraphs=False)
    # yield benchmark("Lowering", compiled_nondecomposed, *args, kwargs)
    # Just show the first two generated kernels

    cuda_f = cudagraphify_impl(f, args, static_input_idxs=tuple(range(len(args))))
    yield benchmark(f"Eager", cuda_f, *args, kwargs)


results = []
for x in gen_inputs():
    for res in compare(x):
        results.append(res)

compare = Compare(results)
compare.trim_significant_figures()
compare.print()


@lezcano left a comment


Code and benchmarks look great! The trick you used to generate the kernel is quite interesting... but it's also a bit sad that we need to do this manually and that Inductor doesn't do it on its own.

isuruf added a commit that referenced this pull request Apr 8, 2024
ghstack-source-id: 5d316aa9b5636181536effd913e4ff141ffd266d
Pull Request resolved: #123222
@@ -2330,9 +2330,7 @@ def is_aligned(x):
make_fallback(aten.fractional_max_pool3d_backward)
make_fallback(aten.replication_pad1d_backward)
make_fallback(aten.replication_pad2d_backward)
make_fallback(aten.upsample_linear1d_backward)
make_fallback(aten.upsample_bicubic2d_backward, require_contiguous)
Collaborator

Why not this one?

Collaborator Author

I'll keep the linear ones for this PR and do bicubic in a follow-up PR.

isuruf added a commit that referenced this pull request Apr 22, 2024
ghstack-source-id: 10fbbda017a05e5a3281b230f1a04b459c42dcd2
Pull Request resolved: #123222
Comment on lines 3750 to 3765
# Using functions here to help inductor do fusions easier.
# This results in a * (b * (c * x)) instead of (a * (b * c)) * x
# in the generated code and therefore inductor does not try to
# create temporaries for (b * c).
coeff_fns = [lambda x: x]
for i in range(n_dims):
    xscale = (xs_f32[i] - xs[i]).clamp(0.0, 1.0).to(dtype)
    new_coeff_fns: List[Any] = [None] * (2 * len(coeff_fns))
    new_coeff_fns[::2] = [
        lambda x: torch.mul(coeff_fn(x), (1 - xscale)) for coeff_fn in coeff_fns
    ]
    new_coeff_fns[1::2] = [
        lambda x: torch.mul(coeff_fn(x), xscale) for coeff_fn in coeff_fns
    ]
    coeff_fns = new_coeff_fns

Contributor

Would you mind filing a standalone issue here on the optimization inductor should be doing?

Collaborator Author

Found #124653 while trying to create an MWE. Not sure if it's the same bug.
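
For concreteness, roughly what the coeff_fns construction quoted above expands to for n_dims == 2 (a standalone sketch; s0 and s1 stand in for the per-dimension xscale tensors, and the shapes are made up):

import torch

s0 = torch.rand(4, 1)   # xscale for the first spatial dim, broadcastable
s1 = torch.rand(1, 6)   # xscale for the second spatial dim
coeff_fns = [
    lambda g: (g * (1 - s0)) * (1 - s1),   # contribution to the (left, left) neighbour
    lambda g: (g * (1 - s0)) * s1,         # (left, right)
    lambda g: (g * s0) * (1 - s1),         # (right, left)
    lambda g: (g * s0) * s1,               # (right, right)
]
grad = torch.rand(4, 6)
weighted = [fn(grad) for fn in coeff_fns]
# Every weight is applied as a chain of single multiplies with the gradient innermost,
# i.e. a * (b * g), so inductor never has to materialize (1 - s0) * (1 - s1) and the
# other combined weights as temporaries.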

isuruf added a commit that referenced this pull request Apr 22, 2024
ghstack-source-id: 662c77dc817f26fd9a34758f21cf6941ecdb31d6
Pull Request resolved: #123222
isuruf added a commit that referenced this pull request Apr 25, 2024
ghstack-source-id: 74b703f23ef08b60f3d05441b713fdfc2efc8ac8
Pull Request resolved: #123222
isuruf added a commit that referenced this pull request Apr 26, 2024
ghstack-source-id: 9ef55f2915fa8891788c7ba3ff935c0fdb470562
Pull Request resolved: #123222
isuruf added a commit that referenced this pull request May 13, 2024
ghstack-source-id: c7c903588d5268d3ef77cd6041da7902fd64681e
Pull Request resolved: #123222
isuruf added a commit that referenced this pull request May 13, 2024
ghstack-source-id: 7ffa4f4508b05aad9f30fd6c8c1a95f76f404ad5
Pull Request resolved: #123222
isuruf added a commit that referenced this pull request May 14, 2024
ghstack-source-id: dda7b7d3c3bf3def6ff42d9b6d8c16518eab5e38
Pull Request resolved: #123222
isuruf added a commit that referenced this pull request May 15, 2024
ghstack-source-id: 2809dcfcdf34a8e117282f90a2834b4c121c0936
Pull Request resolved: #123222

isuruf commented Jun 25, 2024

I'll look at the triton PR, as without it this is worse than eager.

Benchmarks:

                                                  |  Decomposed (with triton PR)  |  Decomposed (without triton PR)  |  Eager
{'scaling': 2, 'shape': (3, 5, 10000000)}         |             1.140             |              1.255               |  1.569
{'scaling': 2, 'shape': (3, 5, 10000, 1000)}      |             1.787             |              2.273               |  2.747
{'scaling': 2, 'shape': (3, 5, 100, 1000, 100)}   |             4.576             |              6.399               |  5.073
