move stuff in the linearizer (tinygrad#1726)
* move stuff in linearizer

* move stuff in linearizer

* minor

* fix opts import
geohot committed Aug 31, 2023
1 parent c18a497 commit 453e437
Showing 9 changed files with 34 additions and 39 deletions.
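The only change most callers will notice is the new home of LinearizerOptions: it is now imported from tinygrad.codegen.kernel instead of tinygrad.codegen.linearizer (the old re-export is removed). A minimal sketch of the updated import, mirroring the docs/ and runtime/ hunks below:

# before this commit:
#   from tinygrad.codegen.linearizer import Linearizer, LinearizerOptions
# after this commit, LinearizerOptions lives in codegen.kernel:
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.codegen.kernel import LinearizerOptions

# the class itself is unchanged; only the import path moved
opts = LinearizerOptions()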
docs/abstractions.py: 3 changes (2 additions, 1 deletion)
@@ -270,7 +270,8 @@ def linearize(self): pass
result.lazydata.realized = Device[Device.DEFAULT].buffer(prod(result.shape), result.dtype)

# use the real Linearizer to linearize 2+3
-from tinygrad.codegen.linearizer import Linearizer, LinearizerOptions
+from tinygrad.codegen.linearizer import Linearizer
+from tinygrad.codegen.kernel import LinearizerOptions
linearizer = Linearizer(result.lazydata.op, result.lazydata, LinearizerOptions())
linearizer.linearize()

tinygrad/codegen/linearizer.py: 56 changes (25 additions, 31 deletions)
@@ -11,7 +11,7 @@
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, sym_rename
from tinygrad.codegen.optimizer import OptimizedKernel
-from tinygrad.codegen.kernel import LocalBuffer, LinearizerOptions # noqa: F401 # pylint:disable=unused-import
+from tinygrad.codegen.kernel import LocalBuffer
VariableOrNum = Union[Variable, NumNode, Node]

# bottom ones are asm only
@@ -205,56 +205,57 @@ def linearize(self):

# uops
self.uops: List[UOp] = []
+self.load_cache: Dict[str, Token] = {}
+self.saved_exprs: Dict[Tuple[Op, Tuple[Token, ...]], Token] = dict()

# add global buffers
for buf,name in self.arg_bufs.items():
self.uop(UOps.DEFINE_GLOBAL, None, [], (name, buf.dtype))
# add variables from symbolic shapes
for var in sorted(set(v for buf in self.ast.buffers for v in buf.st.var_vals), key=lambda k: k.key):
self.uop(UOps.DEFINE_GLOBAL, None, [], (var.expr, dtypes._arg_int32))

-# add a local buffer for multistage reduce
+# define local buffers
+for lb in self.local_alias.values():
+self.uop(UOps.DEFINE_LOCAL, None, [], (lb.name, self.sts[self.bufs.index(lb)].size()))
+# add a local buffer for multistage reduce. # TODO: use local alias
if self.group_for_reduce:
# TODO: the strides of this can be controlled
self.sts.append(ShapeTracker(tuple([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))
self.bufs.append(LocalBuffer("temp", self.sts[-1].size()))
self.uop(UOps.DEFINE_LOCAL, None, [], ("temp", self.sts[-1].size()))

-# define local buffers
-for lb in self.local_alias.values():
-self.uop(UOps.DEFINE_LOCAL, None, [], (lb.name, self.sts[self.bufs.index(lb)].size()))

# print
if DEBUG >= 3: self.printbufs()

# kernel name (before late upcast)
self.function_name = ("r_" if self.reduceop else "E_") + '_'.join([str(x) if isinstance(x, int) else sym_rename(x) for x in self.full_shape])
self.display_name = ("r_" if self.reduceop else "E_") + colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])

+# name the function something unique
+Linearizer.kernel_cnt[self.function_name] += 1
+suffix = f"{'n'+str(Linearizer.kernel_cnt[self.function_name]-1)}" if Linearizer.kernel_cnt[self.function_name] > 1 else ""
+self.function_name, self.display_name = self.function_name+suffix, self.display_name+colored(suffix, 'BLACK')

+# define indexes
+global_idxs = [Variable(f"gidx{i}", 0, self.full_shape[i]-1) for i in range(0, self.first_reduce-self.local_dims)]
+local_idxs = [Variable(f"lidx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce-self.local_dims, self.first_reduce+len(self.group_for_reduce))]
+full_upcast_idxs = [Variable(None, 0, s-1) for s in self.full_shape[self.shape_len-self.upcasted:]]
+upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]

+# global and local loops
+self.uop(UOps.LOOP, None, [], (global_idxs, "global"))
+self.uop(UOps.LOOP, None, [], (local_idxs, "local"))

# parse AST
loaded_buffers = {}
acc = []
-self.load_cache: Dict[str, Token] = {}
-self.saved_exprs: Dict[Tuple[Op, Tuple[Token, ...]], Token] = dict()

# ssa
_ssa:DefaultDict[str,int] = defaultdict(int)
def ssa(name, ltype=dtypes.float) -> Token:
_ssa[name] += 1
return Token(f"{name}{_ssa[name]-1}", ltype)

-# global loop
-global_idxs = [Variable(f"gidx{i}", 0, self.full_shape[i]-1) for i in range(0, self.first_reduce-self.local_dims)]
-self.uop(UOps.LOOP, None, [], (global_idxs, "global"))

-# local loop
-local_idxs = [Variable(f"lidx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce-self.local_dims, self.first_reduce+len(self.group_for_reduce))]
-self.uop(UOps.LOOP, None, [], (local_idxs, "local"))

-# upcast indexes
-full_upcast_idxs = [Variable(None, 0, s-1) for s in self.full_shape[self.shape_len-self.upcasted:]]
-upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]

# reduce op
fake_reduce_idxs = []
if self.reduceop is not None:
@@ -272,6 +273,7 @@ def ssa(name, ltype=dtypes.float) -> Token:
if self.use_tensor_cores: self.uop(UOps.BARRIER, None, [], ())

# compute local aliases
+# TODO: this is garbage code and should be at least moved elsewhere
locals_to_store = []
for i in self.local_alias:
strides = self.sts[i].real_strides()
@@ -381,17 +383,9 @@ def ssa(name, ltype=dtypes.float) -> Token:
# store
self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val, ssa)

-if not self.group_for_reduce:
-# end the global+local loop
-self.uop(UOps.ENDLOOP, None, [], (global_idxs+local_idxs, "global+local"))
-else:
-# end the global loop
-self.uop(UOps.ENDLOOP, None, [], (global_idxs, "global"))
+# end the global (and maybe local) loop
+self.uop(UOps.ENDLOOP, None, [], (global_idxs+local_idxs, "global+local") if not self.group_for_reduce else (global_idxs, "global"))

-# name the function something unique
-Linearizer.kernel_cnt[self.function_name] += 1
-suffix = f"{'n'+str(Linearizer.kernel_cnt[self.function_name]-1)}" if Linearizer.kernel_cnt[self.function_name] > 1 else ""
-self.function_name, self.display_name = self.function_name+suffix, self.display_name+colored(suffix, 'BLACK')
return self

_OT = TypeVar("_OT")
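Read together, the added and removed lines above amount to a reordering of linearize() rather than a behavioral change. A rough outline of the post-commit flow, reconstructed from the hunks above (a sketch, not a verbatim copy of the method):

# 1. reset uop state: self.uops, self.load_cache, self.saved_exprs
# 2. DEFINE_GLOBAL uops for argument buffers and symbolic-shape variables
# 3. DEFINE_LOCAL uops for local aliases, then the "temp" buffer for multistage reduce
# 4. pick function_name/display_name and append the unique kernel_cnt suffix right away
# 5. build global/local/upcast index Variables and open the global and local LOOPs
# 6. parse the AST: reduce op, local aliases, loads, ALU ops, store
# 7. close with a single ENDLOOP (global+local, or global only when group_for_reduce is set)
# 8. return self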
tinygrad/runtime/ops_clang.py: 2 changes (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
from tinygrad.ops import Compiled
from tinygrad.helpers import fromimport, getenv, DEBUG, CI
from tinygrad.runtime.lib import RawMallocBuffer
-from tinygrad.codegen.linearizer import LinearizerOptions
+from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
import struct
import numpy as np
tinygrad/runtime/ops_cuda.py: 2 changes (1 addition, 1 deletion)
@@ -6,7 +6,7 @@
from tinygrad.helpers import DEBUG, getenv, colored, fromimport
from tinygrad.ops import Compiled
from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer, LRUAllocator
-from tinygrad.codegen.linearizer import LinearizerOptions
+from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage

def pretty_ptx(s):
tinygrad/runtime/ops_gpu.py: 2 changes (1 addition, 1 deletion)
@@ -6,7 +6,7 @@
from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, fromimport
from tinygrad.ops import Compiled
from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer
-from tinygrad.codegen.linearizer import LinearizerOptions
+from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage

OSX_TIMING_RATIO = (125/3) if OSX else 1.0 # see test/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something
tinygrad/runtime/ops_hip.py: 2 changes (1 addition, 1 deletion)
@@ -4,7 +4,7 @@
from tinygrad.helpers import DEBUG
from tinygrad.ops import Compiled
from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator
-from tinygrad.codegen.linearizer import LinearizerOptions
+from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage

# TODO: if you fork and exit the child process after creating anything with cl on AMD, it hangs on e.wait()
tinygrad/runtime/ops_llvm.py: 2 changes (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
from tinygrad.ops import Compiled
from tinygrad.helpers import getenv, DEBUG
from ctypes import CFUNCTYPE
-from tinygrad.codegen.linearizer import LinearizerOptions
+from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.llvmir import uops_to_llvm_ir
from tinygrad.runtime.lib import RawMallocBuffer

tinygrad/runtime/ops_metal.py: 2 changes (1 addition, 1 deletion)
@@ -2,7 +2,7 @@
import os, subprocess, pathlib, functools, ctypes
import Metal, Cocoa, libdispatch # type: ignore
from typing import List, Any
-from tinygrad.codegen.linearizer import LinearizerOptions
+from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes
from tinygrad.ops import Compiled
tinygrad/runtime/ops_webgpu.py: 2 changes (1 addition, 1 deletion)
@@ -4,7 +4,7 @@
from tinygrad.runtime.lib import RawBufferCopyIn, LRUAllocator
from tinygrad.helpers import dtypes, DType
from tinygrad.ops import Compiled
-from tinygrad.codegen.linearizer import LinearizerOptions
+from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle
from tinygrad.renderer.wgsl import WGSLLanguage
import wgpu # type: ignore