Update kernels.py
leogao2 committed Jun 11, 2024
1 parent 0296e02 commit 3237a47
Showing 1 changed file with 287 additions and 0 deletions: sparse_autoencoder/kernels.py
@@ -416,3 +416,290 @@ def backward(ctx, grad_output):
decoder_grad,
None,
)


def triton_add_mul_(
x: torch.Tensor,
a: torch.Tensor,
b: torch.Tensor,
c: float,
):
"""
does
x += a * b * c
x : [m, n]
a : [m, n]
b : [m, n]
c : float
"""

if len(a.shape) == 1:
a = a[None, :].broadcast_to(x.shape)

if len(b.shape) == 1:
b = b[None, :].broadcast_to(x.shape)

assert x.shape == a.shape == b.shape

BLOCK_SIZE_M = 64
BLOCK_SIZE_N = 64
grid = lambda META: (
triton.cdiv(x.shape[0], META["BLOCK_SIZE_M"]),
triton.cdiv(x.shape[1], META["BLOCK_SIZE_N"]),
)
triton_add_mul_kernel[grid](
x,
a,
b,
c,
x.stride(0),
x.stride(1),
a.stride(0),
a.stride(1),
b.stride(0),
b.stride(1),
BLOCK_SIZE_M,
BLOCK_SIZE_N,
x.shape[0],
x.shape[1],
)
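# Illustrative usage (a sketch, not part of the original file; shapes are
# hypothetical): the fused update computes `x += a * b * c` in one kernel
# launch, without materializing the intermediate `a * b` tensor.
#
#     x = torch.randn(128, 256, device="cuda", dtype=torch.float16)
#     a = torch.randn(128, 256, device="cuda", dtype=torch.float16)
#     b = torch.randn(256, device="cuda", dtype=torch.float16)  # broadcast over rows
#     triton_add_mul_(x, a, b, 0.5)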


@triton.jit
def triton_add_mul_kernel(
x_ptr,
a_ptr,
b_ptr,
c,
stride_x0,
stride_x1,
stride_a0,
stride_a1,
stride_b0,
stride_b1,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
M: tl.constexpr,
N: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)

offsets_m = tl.arange(0, BLOCK_SIZE_M) + pid_m * BLOCK_SIZE_M
offsets_n = tl.arange(0, BLOCK_SIZE_N) + pid_n * BLOCK_SIZE_N

x = tl.load(
x_ptr + offsets_m[:, None] * stride_x0 + offsets_n[None, :] * stride_x1,
mask=(offsets_m[:, None] < M) & (offsets_n[None, :] < N),
)
a = tl.load(
a_ptr + offsets_m[:, None] * stride_a0 + offsets_n[None, :] * stride_a1,
mask=(offsets_m[:, None] < M) & (offsets_n[None, :] < N),
)
b = tl.load(
b_ptr + offsets_m[:, None] * stride_b0 + offsets_n[None, :] * stride_b1,
mask=(offsets_m[:, None] < M) & (offsets_n[None, :] < N),
)

x_dtype = x.dtype
x = (x.to(tl.float32) + a.to(tl.float32) * b.to(tl.float32) * c).to(x_dtype)

tl.store(
x_ptr + offsets_m[:, None] * stride_x0 + offsets_n[None, :] * stride_x1,
x,
mask=(offsets_m[:, None] < M) & (offsets_n[None, :] < N),
)
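# Reference check (illustrative, not part of the original file): the kernel
# should agree with the eager PyTorch computation, up to fp16 rounding on the
# final store:
#
#     expected = (x.float() + a.float() * b.float() * c).to(x.dtype)
#     triton_add_mul_(x, a, b, c)
#     torch.testing.assert_close(x, expected)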


def triton_sum_dim0_in_fp32(xs):
a, b = xs.shape

assert xs.is_contiguous()
assert xs.dtype == torch.float16

BLOCK_SIZE_A = min(triton.next_power_of_2(a), 512)
BLOCK_SIZE_B = 64 # cache line is 128 bytes

out = torch.zeros(b, dtype=torch.float32, device=xs.device)

grid = lambda META: (triton.cdiv(b, META["BLOCK_SIZE_B"]),)

triton_sum_dim0_in_fp32_kernel[grid](
xs,
out,
stride_a=xs.stride(0),
a=a,
b=b,
BLOCK_SIZE_A=BLOCK_SIZE_A,
BLOCK_SIZE_B=BLOCK_SIZE_B,
)

return out
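# Illustrative usage (a sketch, not part of the original file): summing over
# dim 0 with an fp32 accumulator avoids the precision loss of reducing many
# fp16 values directly.
#
#     xs = torch.randn(4096, 512, device="cuda", dtype=torch.float16)
#     col_sums = triton_sum_dim0_in_fp32(xs)  # fp32 tensor of shape [512]
#     # reference: torch.testing.assert_close(col_sums, xs.float().sum(dim=0))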


@triton.jit
def triton_sum_dim0_in_fp32_kernel(
xs_ptr,
out_ptr,
stride_a,
a,
b,
BLOCK_SIZE_A: tl.constexpr,
BLOCK_SIZE_B: tl.constexpr,
):
# each program handles BLOCK_SIZE_B columns of xs (64 in the wrapper above)
pid = tl.program_id(0)
offsets_b = tl.arange(0, BLOCK_SIZE_B) + pid * BLOCK_SIZE_B

all_out = tl.zeros((BLOCK_SIZE_B,), dtype=tl.float32)

for i in range(0, a, BLOCK_SIZE_A):
offsets_a = tl.arange(0, BLOCK_SIZE_A) + i
xs = tl.load(
xs_ptr + offsets_a[:, None] * stride_a + offsets_b[None, :],
mask=(offsets_a < a)[:, None] & (offsets_b < b)[None, :],
other=0,
)
xs = xs.to(tl.float32)
out = tl.sum(xs, axis=0)
all_out += out

tl.store(out_ptr + offsets_b, all_out, mask=offsets_b < b)
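# Illustrative note: each program owns BLOCK_SIZE_B columns and walks down the
# rows in chunks of BLOCK_SIZE_A, keeping a running fp32 accumulator, so the
# column sums never round-trip through fp16.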


def mse(
output,
target,
): # fusing fp32 cast and MSE to save memory
assert output.shape == target.shape
assert len(output.shape) == 2
assert output.stride(1) == 1
assert target.stride(1) == 1

a, b = output.shape

BLOCK_SIZE_B = triton.next_power_of_2(b)

class _MSE(torch.autograd.Function):
@staticmethod
def forward(ctx, output, target):
ctx.save_for_backward(output, target)
out = torch.zeros(a, dtype=torch.float32, device=output.device)

triton_mse_loss_fp16_kernel[(a,)](
output,
target,
out,
stride_a_output=output.stride(0),
stride_a_target=target.stride(0),
a=a,
b=b,
BLOCK_SIZE_B=BLOCK_SIZE_B,
)

return out

@staticmethod
def backward(ctx, grad_output):
output, target = ctx.saved_tensors
res = (output - target).float()
res *= grad_output[:, None] * 2 / b
return res, None

return _MSE.apply(output, target).mean()
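# Illustrative usage (a sketch, not part of the original file): the result
# matches a plain fp32 MSE, but the forward pass never materializes fp32
# copies of `output` and `target`.
#
#     loss = mse(recon, xs)  # scalar fp32 tensor
#     # reference: ((recon.float() - xs.float()) ** 2).mean()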


def normalized_mse(recon: torch.Tensor, xs: torch.Tensor) -> torch.Tensor:
# only used for auxk
xs_mu = (
triton_sum_dim0_in_fp32(xs) / xs.shape[0]
if xs.dtype == torch.float16
else xs.mean(dim=0)
)

loss = mse(recon, xs) / mse(
xs_mu[None, :].broadcast_to(xs.shape), xs
)

return loss
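# Illustrative note: this is the reconstruction MSE normalized by the MSE of a
# baseline that always predicts the per-column batch mean, i.e.
#
#     normalized_mse = mean((recon - xs)^2) / mean((xs_mu - xs)^2)
#
# so a value of 1.0 means the reconstruction is no better than predicting the mean.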


@triton.jit
def triton_mse_loss_fp16_kernel(
output_ptr,
target_ptr,
out_ptr,
stride_a_output,
stride_a_target,
a,
b,
BLOCK_SIZE_B: tl.constexpr,
):
pid = tl.program_id(0)
offsets_b = tl.arange(0, BLOCK_SIZE_B)

output = tl.load(
output_ptr + pid * stride_a_output + offsets_b,
mask=offsets_b < b,
)
target = tl.load(
target_ptr + pid * stride_a_target + offsets_b,
mask=offsets_b < b,
)

output = output.to(tl.float32)
target = target.to(tl.float32)

mse = tl.sum((output - target) * (output - target)) / b

tl.store(out_ptr + pid, mse)
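# Illustrative note: the grid launches one program per row (pid indexes rows),
# and BLOCK_SIZE_B = next_power_of_2(b) must cover an entire row, so this
# kernel assumes a full row fits in a single block.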


