Decompositions for upsample linear backward #123222
base: gh/isuruf/37/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/123222
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 unrelated failure) As of commit 349827a with merge base c888ee3: UNSTABLE - the following job failed but was likely due to flakiness present on trunk and has been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Need triton-lang/triton#3491 for performance.
Need to remove the fallback for these.
With the latest commit, here are some benchmarks. (I'm not sure what the appropriate values are for a benchmark here.)
An idea: I think we can do all of this with just one index_put_ and broadcasting.
There are a few ideas here:
- We have 2^n indices (and coefficients) of length d_i each to put together in one tensor.
- To do that, the idea is to put together n one-dimensional tensors of length 2*d_i each.
The naïve way of doing this would be via torch.cat, but that doesn't fuse well in inductor, so we do it by playing with the indexing instead.
To compute cat(arange(n), arange(n)) we can do arange(2n) % n. With that we can get x2_f32 = cat(x_f32, x_f32).
To get cat(x, xp1), we can do (x_f32.to(int64) + (arange(2n) // n)).clamp(max=inp_size - 1).
Doing that for each dimension and broadcasting, we get all the indices.
With pretty much the same tricks we can get the coefficients, and then do one index_put_.
It would be great to see if this can generate just one performant kernel!
NB: the same tricks apply to the forward of this op, if I'm not mistaken. Funnily enough, this just occurred to me; I didn't manage to write this in one kernel a year ago when I implemented decomps for a couple of these ops :D
Thanks for the suggestion. I tried that method earlier in the week (isuruf@ecef740) and checked that the generated Triton code was a single kernel.
That's fair enough. Looking at the eager implementation, I see we end up doing it like your current approach, where each thread does 2^n atomic_adds. Mind posting the Triton kernels for the current approach and the benchmarking script?
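For the 1-D case (n = 1, so two scatter-adds per output element), the structure the generated kernel's tl.atomic_add pairs implement can be sketched with index_put_(accumulate=True). This is a simplified illustration, not the PR's exact decomposition, and the coordinate mapping is reduced to a plain scale factor:

```python
import torch

def upsample_linear1d_backward_sketch(grad_out: torch.Tensor, inp_size: int, scale: float) -> torch.Tensor:
    # Map each output coordinate back to a fractional input coordinate.
    out_size = grad_out.numel()
    x_f32 = torch.arange(out_size, dtype=torch.float32) * scale
    x = x_f32.to(torch.int64)                           # left source index
    w = (x_f32 - x.to(torch.float32)).clamp(0.0, 1.0)   # weight of right neighbor
    xp1 = (x + 1).clamp(max=inp_size - 1)               # right source index
    grad_in = torch.zeros(inp_size, dtype=grad_out.dtype)
    # Each output element contributes to 2 input elements -> 2 atomic adds
    # per thread in the generated kernel.
    grad_in.index_put_((x,), grad_out * (1 - w), accumulate=True)
    grad_in.index_put_((xp1,), grad_out * w, accumulate=True)
    return grad_in
```

Since the two weights sum to 1 per output element, the total gradient mass is conserved, which is a convenient sanity check.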
3D:

from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_isuruf/sh/csh7njex4jcmtkekmbkknh5jednzydzaderybrvqvfdldqetdp57.py
# Source Nodes: [], Original ATen: []
triton_poi_fused_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor
@triton_heuristics.pointwise(
    size_hints=[33554432],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_0', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '13d5379970a55f2f2c4bb8dbeb907c03d2af7e5fb1c9d1b1aa5bf5794d5f2277'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 18750000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = 0.0
    tl.store(out_ptr0 + (x0), tmp0, xmask)
''', device_str='cuda')
import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# kernel path: /tmp/torchinductor_isuruf/26/c26lclpgh3vwdbju5ij5h7zyr3pgu42s246arjplhyyedtcatj2n.py
# Source Nodes: [], Original ATen: []
triton_poi_fused_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor
@triton_heuristics.pointwise(
    size_hints=[268435456],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_1', 'mutated_arg_names': ['out_ptr0'], 'no_x_dim': False, 'backend_hash': '13d5379970a55f2f2c4bb8dbeb907c03d2af7e5fb1c9d1b1aa5bf5794d5f2277'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 150000000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = (xindex // 100000) % 100
    x1 = (xindex // 100) % 1000
    x0 = xindex % 100
    x4 = xindex
    x3 = (xindex // 10000000)
    tmp18 = tl.load(in_ptr0 + (x4), xmask)
    tmp0 = x2
    tmp1 = tmp0.to(tl.float32)
    tmp2 = 0.494949494949495
    tmp3 = tmp1 * tmp2
    tmp4 = 0.0
    tmp5 = triton_helpers.maximum(tmp3, tmp4)
    tmp6 = tmp5.to(tl.int32)
    tmp7 = x1
    tmp8 = tmp7.to(tl.float32)
    tmp9 = 0.4994994994994995
    tmp10 = tmp8 * tmp9
    tmp11 = triton_helpers.maximum(tmp10, tmp4)
    tmp12 = tmp11.to(tl.int32)
    tmp13 = x0
    tmp14 = tmp13.to(tl.float32)
    tmp15 = tmp14 * tmp2
    tmp16 = triton_helpers.maximum(tmp15, tmp4)
    tmp17 = tmp16.to(tl.int32)
    tmp19 = tmp17.to(tl.float32)
    tmp20 = tmp16 - tmp19
    tmp21 = triton_helpers.maximum(tmp20, tmp4)
    tmp22 = 1.0
    tmp23 = triton_helpers.minimum(tmp21, tmp22)
    tmp24 = tmp18 * tmp23
    tmp25 = tmp24 * tmp23
    tmp26 = tmp22 - tmp23
    tmp27 = tmp25 * tmp26
    tmp28 = tl.full([1], 1, tl.int64)
    tmp29 = tmp17 + tmp28
    tmp30 = tl.full([1], 49, tl.int64)
    tmp31 = triton_helpers.minimum(tmp29, tmp30)
    tmp32 = tmp25 * tmp23
    tmp33 = tmp12 + tmp28
    tmp34 = tl.full([1], 499, tl.int64)
    tmp35 = triton_helpers.minimum(tmp33, tmp34)
    tmp36 = tmp6 + tmp28
    tmp37 = triton_helpers.minimum(tmp36, tmp30)
    tl.atomic_add(out_ptr0 + (tmp17 + (50*tmp12) + (25000*tmp6) + (1250000*x3)), tmp27, xmask)
    tl.atomic_add(out_ptr0 + (tmp31 + (50*tmp12) + (25000*tmp6) + (1250000*x3)), tmp32, xmask)
    tl.atomic_add(out_ptr0 + (tmp17 + (50*tmp35) + (25000*tmp6) + (1250000*x3)), tmp27, xmask)
    tl.atomic_add(out_ptr0 + (tmp31 + (50*tmp35) + (25000*tmp6) + (1250000*x3)), tmp32, xmask)
    tl.atomic_add(out_ptr0 + (tmp17 + (50*tmp12) + (25000*tmp37) + (1250000*x3)), tmp27, xmask)
    tl.atomic_add(out_ptr0 + (tmp31 + (50*tmp12) + (25000*tmp37) + (1250000*x3)), tmp32, xmask)
    tl.atomic_add(out_ptr0 + (tmp17 + (50*tmp35) + (25000*tmp37) + (1250000*x3)), tmp27, xmask)
    tl.atomic_add(out_ptr0 + (tmp31 + (50*tmp35) + (25000*tmp37) + (1250000*x3)), tmp32, xmask)
''', device_str='cuda')
async_compile.wait(globals())
del async_compile
def call(args):
    args_1, = args
    args.clear()
    assert_size_stride(args_1, (3, 5, 100, 1000, 100), (50000000, 10000000, 100000, 100, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((3, 5, 50, 500, 50), (6250000, 1250000, 25000, 50, 1), torch.float32)
        # Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_0.run(buf0, 18750000, grid=grid(18750000), stream=stream0)
        # Source Nodes: [], Original ATen: []
        triton_poi_fused_1.run(args_1, buf0, 150000000, grid=grid(150000000), stream=stream0)
        del args_1
    return (buf0, )

def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    args_1 = rand_strided((3, 5, 100, 1000, 100), (50000000, 10000000, 100000, 100, 1), device='cuda:0', dtype=torch.float32)
    fn = lambda: call([args_1])
    return print_performance(fn, times=times, repeat=repeat)

if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)

2D:

from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_isuruf/by/cbyr7653xub6dabl6wfbowmcedms4u3auahw3nwzyho4ewtg57bj.py
# Source Nodes: [], Original ATen: []
triton_poi_fused_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor
@triton_heuristics.pointwise(
    size_hints=[67108864],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_0', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '13d5379970a55f2f2c4bb8dbeb907c03d2af7e5fb1c9d1b1aa5bf5794d5f2277'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 37500000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = 0.0
    tl.store(out_ptr0 + (x0), tmp0, xmask)
''', device_str='cuda')
import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# kernel path: /tmp/torchinductor_isuruf/4f/c4f7tg334eiuon7li3rdg4g6ciqrjsa7bq5ye7lwmoiqkop4onwg.py
# Source Nodes: [], Original ATen: []
triton_poi_fused_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor
@triton_heuristics.pointwise(
    size_hints=[268435456],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_1', 'mutated_arg_names': ['out_ptr0'], 'no_x_dim': False, 'backend_hash': '13d5379970a55f2f2c4bb8dbeb907c03d2af7e5fb1c9d1b1aa5bf5794d5f2277'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 150000000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = (xindex // 1000) % 10000
    x0 = xindex % 1000
    x3 = xindex
    x2 = (xindex // 10000000)
    tmp13 = tl.load(in_ptr0 + (x3), xmask)
    tmp0 = x1
    tmp1 = tmp0.to(tl.float32)
    tmp2 = 0.49994999499949994
    tmp3 = tmp1 * tmp2
    tmp4 = 0.0
    tmp5 = triton_helpers.maximum(tmp3, tmp4)
    tmp6 = tmp5.to(tl.int32)
    tmp7 = x0
    tmp8 = tmp7.to(tl.float32)
    tmp9 = 0.4994994994994995
    tmp10 = tmp8 * tmp9
    tmp11 = triton_helpers.maximum(tmp10, tmp4)
    tmp12 = tmp11.to(tl.int32)
    tmp14 = tmp12.to(tl.float32)
    tmp15 = tmp11 - tmp14
    tmp16 = triton_helpers.maximum(tmp15, tmp4)
    tmp17 = 1.0
    tmp18 = triton_helpers.minimum(tmp16, tmp17)
    tmp19 = tmp13 * tmp18
    tmp20 = tmp17 - tmp18
    tmp21 = tmp19 * tmp20
    tmp22 = tl.full([1], 1, tl.int64)
    tmp23 = tmp12 + tmp22
    tmp24 = tl.full([1], 499, tl.int64)
    tmp25 = triton_helpers.minimum(tmp23, tmp24)
    tmp26 = tmp19 * tmp18
    tmp27 = tmp6 + tmp22
    tmp28 = tl.full([1], 4999, tl.int64)
    tmp29 = triton_helpers.minimum(tmp27, tmp28)
    tl.atomic_add(out_ptr0 + (tmp12 + (500*tmp6) + (2500000*x2)), tmp21, xmask)
    tl.atomic_add(out_ptr0 + (tmp25 + (500*tmp6) + (2500000*x2)), tmp26, xmask)
    tl.atomic_add(out_ptr0 + (tmp12 + (500*tmp29) + (2500000*x2)), tmp21, xmask)
    tl.atomic_add(out_ptr0 + (tmp25 + (500*tmp29) + (2500000*x2)), tmp26, xmask)
''', device_str='cuda')
async_compile.wait(globals())
del async_compile
def call(args):
    args_1, = args
    args.clear()
    assert_size_stride(args_1, (3, 5, 10000, 1000), (50000000, 10000000, 1000, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((3, 5, 5000, 500), (12500000, 2500000, 500, 1), torch.float32)
        # Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_0.run(buf0, 37500000, grid=grid(37500000), stream=stream0)
        # Source Nodes: [], Original ATen: []
        triton_poi_fused_1.run(args_1, buf0, 150000000, grid=grid(150000000), stream=stream0)
        del args_1
    return (buf0, )

def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    args_1 = rand_strided((3, 5, 10000, 1000), (50000000, 10000000, 1000, 1), device='cuda:0', dtype=torch.float32)
    fn = lambda: call([args_1])
    return print_performance(fn, times=times, repeat=repeat)

if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)

1D:

from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_isuruf/62/c62wpwuw5yo643cfkemlzsyfuzwdrxgqz4xoowq72dxdc2xht5ct.py
# Source Nodes: [], Original ATen: []
triton_poi_fused_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor
@triton_heuristics.pointwise(
    size_hints=[134217728],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_0', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '13d5379970a55f2f2c4bb8dbeb907c03d2af7e5fb1c9d1b1aa5bf5794d5f2277'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 75000000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = 0.0
    tl.store(out_ptr0 + (x0), tmp0, xmask)
''', device_str='cuda')
import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# kernel path: /tmp/torchinductor_isuruf/cg/ccgowdbembfggavdbc3yb3w6ymffzvkkluje3w3vdlqlssfgobua.py
# Source Nodes: [], Original ATen: []
triton_poi_fused_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor
@triton_heuristics.pointwise(
    size_hints=[268435456],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_1', 'mutated_arg_names': ['out_ptr0'], 'no_x_dim': False, 'backend_hash': '13d5379970a55f2f2c4bb8dbeb907c03d2af7e5fb1c9d1b1aa5bf5794d5f2277'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 150000000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 10000000
    x2 = xindex
    x1 = (xindex // 10000000)
    tmp7 = tl.load(in_ptr0 + (x2), xmask)
    tmp0 = x0
    tmp1 = tmp0.to(tl.float32)
    tmp2 = 0.499999949999995
    tmp3 = tmp1 * tmp2
    tmp4 = 0.0
    tmp5 = triton_helpers.maximum(tmp3, tmp4)
    tmp6 = tmp5.to(tl.int32)
    tmp8 = tmp6.to(tl.float32)
    tmp9 = tmp5 - tmp8
    tmp10 = triton_helpers.maximum(tmp9, tmp4)
    tmp11 = 1.0
    tmp12 = triton_helpers.minimum(tmp10, tmp11)
    tmp13 = tmp11 - tmp12
    tmp14 = tmp7 * tmp13
    tmp15 = tl.full([1], 1, tl.int64)
    tmp16 = tmp6 + tmp15
    tmp17 = tl.full([1], 4999999, tl.int64)
    tmp18 = triton_helpers.minimum(tmp16, tmp17)
    tmp19 = tmp7 * tmp12
    tl.atomic_add(out_ptr0 + (tmp6 + (5000000*x1)), tmp14, xmask)
    tl.atomic_add(out_ptr0 + (tmp18 + (5000000*x1)), tmp19, xmask)
''', device_str='cuda')
async_compile.wait(globals())
del async_compile
def call(args):
    args_1, = args
    args.clear()
    assert_size_stride(args_1, (3, 5, 10000000), (50000000, 10000000, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((3, 5, 5000000), (25000000, 5000000, 1), torch.float32)
        # Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_0.run(buf0, 75000000, grid=grid(75000000), stream=stream0)
        # Source Nodes: [], Original ATen: []
        triton_poi_fused_1.run(args_1, buf0, 150000000, grid=grid(150000000), stream=stream0)
        del args_1
    return (buf0, )

def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    args_1 = rand_strided((3, 5, 10000000), (50000000, 10000000, 1), device='cuda:0', dtype=torch.float32)
    fn = lambda: call([args_1])
    return print_performance(fn, times=times, repeat=repeat)

if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)

Benchmark script:

import torch
from torch.testing import make_tensor
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils.benchmark import Timer, Compare
from torch._inductor.compile_fx import compile_fx_inner, cudagraphify_impl
from torch._inductor.decomposition import decompositions
from itertools import product
from functools import partial

aten = torch.ops.aten
torch._logging.set_logs(output_code=True)

benchmark_name = "upsample_linear_backward"

Ss = [
    [3, 5, 10000000, 2],
    [3, 5, 10000, 1000, 2],
    [3, 5, 100, 1000, 100, 2],
]

def gen_inputs():
    for inp in Ss:
        shape = inp[:-1]
        yield [torch.randn(shape, dtype=torch.float32, device="cuda"), inp[-1:]]

def benchmark(label, f, x, kwargs):
    kwargs = kwargs.copy()
    kwargs["shape"] = tuple(x.shape)
    return Timer(
        "f([x,])",
        globals=locals(),
        label=benchmark_name,
        description=label,
        sub_label=str(kwargs),
        num_threads=torch.get_num_threads(),
    ).blocked_autorange(min_run_time=4)

def compare(args):
    x, props = args
    print(f"{tuple(x.shape)}")
    if x.dim() == 5:
        aten_func = aten.upsample_trilinear3d_backward
    elif x.dim() == 4:
        aten_func = aten.upsample_bilinear2d_backward
    else:
        aten_func = aten.upsample_linear1d_backward
    kwargs = {
        "scaling": props[0],
    }
    inp_size = [*x.shape[:2], *[d // props[0] for d in x.shape[2:]]]
    def f(args):
        (x,) = args
        val = aten_func(x, x.shape[2:], inp_size, align_corners=True)
        return (val,)
    args = [x]
    decomposed = make_fx(f, decomposition_table=decompositions, tracing_mode="fake")(args)
    compiled_decomposed = compile_fx_inner(decomposed, args, cudagraphs=False)
    yield benchmark("Decomposed", compiled_decomposed, *args, kwargs)
    # non_decomposed = make_fx(f, tracing_mode="fake")(args)
    # compiled_nondecomposed = compile_fx_inner(non_decomposed, args, cudagraphs=False)
    # yield benchmark("Lowering", compiled_nondecomposed, *args, kwargs)
    # Just show the first two generated kernels
    cuda_f = cudagraphify_impl(f, args, static_input_idxs=tuple(range(len(args))))
    yield benchmark("Eager", cuda_f, *args, kwargs)

results = []
for x in gen_inputs():
    for res in compare(x):
        results.append(res)

compare = Compare(results)
compare.trim_significant_figures()
compare.print()
Code and benchmarks look great! The trick you used to generate the kernel is quite interesting... but it's also a bit sad that we need to do this manually and Inductor doesn't do it on its own.
@@ -2330,9 +2330,7 @@ def is_aligned(x):
make_fallback(aten.fractional_max_pool3d_backward)
make_fallback(aten.replication_pad1d_backward)
make_fallback(aten.replication_pad2d_backward)
make_fallback(aten.upsample_linear1d_backward)
make_fallback(aten.upsample_bicubic2d_backward, require_contiguous)
Why not this one?
I'll keep the linear ones for this PR and do bicubic in a follow-up PR.
torch/_decomp/decompositions.py
Outdated
# Using functions here to help inductor do fusions easier.
# This results in a * (b * (c * x)) instead of (a * (b * c)) * x
# in the generated code and therefore inductor does not try to
# create temporaries for (b * c).
coeff_fns = [lambda x: x]
for i in range(n_dims):
    xscale = (xs_f32[i] - xs[i]).clamp(0.0, 1.0).to(dtype)
    new_coeff_fns: List[Any] = [None] * (2 * len(coeff_fns))
    new_coeff_fns[::2] = [
        lambda x: torch.mul(coeff_fn(x), (1 - xscale)) for coeff_fn in coeff_fns
    ]
    new_coeff_fns[1::2] = [
        lambda x: torch.mul(coeff_fn(x), xscale) for coeff_fn in coeff_fns
    ]
    coeff_fns = new_coeff_fns
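The composed-coefficient construction can be illustrated with a standalone sketch (the helper name `build_coeff_fns` is hypothetical; loop variables are bound early via default arguments so each lambda keeps the weight from its own iteration):

```python
import torch

def build_coeff_fns(xscales):
    # For n dims we build 2**n functions; each one multiplies its input by a
    # product of per-dimension weights, one factor of (1 - s) or s per dim.
    coeff_fns = [lambda x: x]
    for s in xscales:
        new_fns = [None] * (2 * len(coeff_fns))
        # Default arguments fn=fn, s=s freeze the current values for each lambda.
        new_fns[::2] = [lambda x, fn=fn, s=s: fn(x) * (1 - s) for fn in coeff_fns]
        new_fns[1::2] = [lambda x, fn=fn, s=s: fn(x) * s for fn in coeff_fns]
        coeff_fns = new_fns
    return coeff_fns
```

Because each dimension's pair of weights sums to 1, the 2^n composed coefficients sum to the identity, which matches the interpolation weights being partitioned across the 2^n neighbors.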
Would you mind filing a standalone issue on the optimization Inductor should be doing?
I found #124653 while trying to create an MWE. Not sure if it's the same bug.
I'll look at the Triton PR, as without it this is worse than eager. Benchmarks:
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @chauhang