Inline barrier (tinygrad#2255)
* put barrier inline for locals

* fix pre-commit on m3

* gate if through barrier
geohot committed Nov 10, 2023
1 parent 75f6e9a commit c0f447d
Showing 5 changed files with 20 additions and 18 deletions.
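The idea in one picture: before this commit, BARRIER was a free-floating side-effect uop with no inputs, the IF around the late reduce was opened and closed by hand (if_gate/END), and correct ordering was purely positional. After it, synchronization is threaded through the uop graph as data dependencies: the BARRIER takes the local-buffer STOREs as inputs, the IF takes the barrier, and the late LOAD takes whichever of the two gates it. A minimal standalone sketch of that dependency shape (U here is a toy stand-in, not tinygrad's UOp):

    from dataclasses import dataclass
    from typing import Tuple

    @dataclass(frozen=True)
    class U:                  # toy stand-in: an op name plus its input uops (vin)
        uop: str
        vin: Tuple["U", ...] = ()

    stores  = (U("STORE"), U("STORE"))        # accumulator stores to the local buffer
    barrier = U("BARRIER", stores)            # the barrier now lists the stores it waits on
    gate    = U("IF", (U("ALU"), barrier))    # the if-gate carries the barrier through
    load    = U("LOAD", (U("DEFINE_LOCAL"), U("CONST"), gate))  # late load depends on the gate

    # the whole synchronization chain is recoverable by walking vin edges
    assert load.vin[-1].vin[-1] is barrier and barrier.vin is stores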
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -27,7 +27,7 @@ repos:
pass_filenames: false
- id: tests
name: subset of (CPU) tests
- entry: env CPU=1 pytest test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py
+ entry: env PYTHONPATH="." CPU=1 pytest test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py
language: system
always_run: true
pass_filenames: false
29 changes: 15 additions & 14 deletions tinygrad/codegen/linearizer.py
@@ -62,7 +62,7 @@ def const(self, b:Union[int,float], dtype=dtypes.int32) -> UOp: return self.uop(
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.ADD), self.nodes[1:], self.nodes[0].render(ops,ctx)),
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL, dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }

- def global_load(self, i:int, idxs:Sequence[Node], acc=None) -> List[UOp]:
+ def global_load(self, i:int, idxs:Sequence[Node], acc=None, barrier:Optional[UOp]=None) -> List[UOp]:
buf = self.bufs[i]
const = buf.val if isinstance(buf, ConstBuffer) else acc

@@ -110,13 +110,13 @@ def rename_var(v: VariableOrNum, expr: str): return v if isinstance(v, NumNode)

if valid.min == 0:
valid_rendered = valid.render(self.render_ops, self)
- self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx, valid_rendered, self.const(invalid_value, localtype)))
+ self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx, valid_rendered, self.const(invalid_value, localtype)) + ((barrier,) if barrier else ()))
else:
- self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx))
+ self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx) + ((barrier,) if barrier else ()))
ret.append(self.uop(UOps.GEP, dtypes.float32, (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
return ret

- def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> None:
+ def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> List[UOp]:
buf = self.bufs[i]
buf_uop = self.buf_uops[i]
assert buf_uop is not None, f"buffer {i} wasn't UOped"
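The ((barrier,) if barrier else ()) append above is the whole mechanism on the load side: the LOAD's vin stays exactly as before when no barrier is passed, and grows by one trailing input when one is. A hedged sketch of the resulting vin layouts (toy function mirroring the two branches above, not tinygrad's API):

    from typing import Tuple

    def load_vin(buf, idx, valid=None, invalid=None, barrier=None) -> Tuple:
        # conditional loads carry (valid, invalid_value); a barrier rides along at the end
        base = (buf, idx, valid, invalid) if valid is not None else (buf, idx)
        return base + ((barrier,) if barrier is not None else ())

    assert len(load_vin("buf", "idx")) == 2                         # plain load
    assert len(load_vin("buf", "idx", barrier="bar")) == 3          # barrier-gated load (new)
    assert len(load_vin("buf", "idx", "valid", "inv")) == 4         # conditional load
    assert len(load_vin("buf", "idx", "valid", "inv", "bar")) == 5  # conditional + barrier

The 3-wide layout is why the cstyle renderer check below moves from len(vin) > 2 to len(vin) > 3.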
@@ -141,14 +141,16 @@ def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> None:
store_offset_new[k] = self.uop(UOps.CAST, dtypes._float4 if amt == 4 else dtypes._float2, tuple(out_tokens))
store_offset = store_offset_new

+ stores = []
for idx, var in store_offset.items():
idx, valid = self.sts[i].expr_idxs(idx)
if isinstance(buf.dtype, ImageDType):
idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
rendered_idx = self.uop(UOps.CAST, dtypes._int2, tuple(x.render(self.render_ops, self) for x in idx))
else:
rendered_idx = idx.render(self.render_ops, self)
- self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var))
+ stores.append(self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var)))
+ return stores

kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
def linearize(self):
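global_store changing its return type from None to List[UOp] is the producer half of the change: the caller gets the STORE uops back and can hand them to BARRIER as explicit inputs (see the reduce path further down). One payoff, sketched standalone: any dependency-respecting walk now orders the stores before the barrier with no positional special-casing (toy classes, not tinygrad's):

    from typing import List, Tuple

    class Op:
        def __init__(self, name: str, vin: Tuple["Op", ...] = ()):
            self.name, self.vin = name, vin

    def walk(root: Op, out: List[Op]) -> List[Op]:
        # plain post-order walk: inputs first, then the op itself
        for v in root.vin:
            if v not in out: walk(v, out)
        out.append(root)
        return out

    stores = [Op("STORE"), Op("STORE")]
    barrier = Op("BARRIER", tuple(stores))
    assert [o.name for o in walk(barrier, [])] == ["STORE", "STORE", "BARRIER"]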
@@ -230,7 +232,6 @@ def render_loop(xx:List[Variable]) -> Tuple[UOp, ...]:
loaded_buffers = {}
acc = []
self.load_cache: Dict[str, UOp] = {}
- if_gate: Optional[UOp] = None

# reduce op
fake_reduce_idxs: List[Variable] = []
@@ -321,13 +322,13 @@ def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
# end the local loop, do the local reduce
if self.group_for_reduce:
fake_global_idxs = [x*0 for x in global_idxs]
- self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators
- self.uop(UOps.BARRIER, None, (), cachable=False)
+ stores = self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators
+ barrier = self.uop(UOps.BARRIER, None, tuple(stores), cachable=False)
if self.opts.has_local:
fake_idxs = [Variable.num(0)]*len(self.sts[-1].shape)
fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
if_cond: UOp = (self.sts[-1].expr_idxs(fake_idxs)[0]<1).render(self.render_ops, self)
- if_gate = self.uop(UOps.IF, None, (if_cond,), cachable=False)
+ barrier = self.uop(UOps.IF, None, (if_cond, barrier), cachable=False)

# create new late reduce local loops and replace local_idxs that have been used
end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
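Note how the same barrier variable is rebound in the hunk above: without locals it stays a bare BARRIER, with locals it becomes an IF whose vin is (if_cond, barrier), and either way it is the single dependency handed to the late global_load below. A sketch of that rebinding (stand-in tuples, not real uops):

    def reduce_gate(has_local: bool, if_cond, barrier):
        # with locals, the IF swallows the barrier and itself becomes
        # the thing the late load must wait on
        return ("IF", (if_cond, barrier)) if has_local else barrier

    barrier = ("BARRIER",)
    dep = reduce_gate(True, "tidx<1", barrier)
    assert dep[0] == "IF" and barrier in dep[1]
    assert reduce_gate(False, "tidx<1", barrier) is barrier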
@@ -352,7 +353,7 @@ def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
loop_ctx = render_loop(end_local_idxs)

# load localbufs
- loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs)
+ loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)

# there's no AST here (and there's no shape for the reduce LazyOp)
self.ast_parse(LazyOp(self.reduceop.op, (self.bufs[-1],)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) # type: ignore
@@ -369,12 +370,9 @@ def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
# store
self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)

- # end the if statement if we used it
- if if_gate: self.uop(UOps.END, None, (if_gate,))

# (recursively) remove childless uops
# NOTE: DEFINE_GLOBAL should be removable, but we'd have to propagate that
- UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.END, UOps.BARRIER, UOps.DEFINE_GLOBAL}
+ UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER, UOps.DEFINE_GLOBAL}
while 1:
has_child: Set[UOp] = set()
for ru in self.uops:
@@ -396,13 +394,16 @@ def get_recursive_deps(x:UOp) -> List[UOp]:
return sorted(list(deps), key=lambda x: x.num)

# add END of loops after the last thing that (recursively) depends on them
+ # and END any if statements
for u in self.uops:
if u.uop == UOps.LOOP:
last_phi = self.uops.index(get_recursive_deps(u)[-1])
at_end = self.uops[last_phi+1:]
self.uops = self.uops[:last_phi+1]
self.uop(UOps.END, None, (u,), cachable=False)
self.uops += at_end
+ elif u.uop == UOps.IF:
+ self.uop(UOps.END, None, (u,), cachable=False)

# maybe graph the uops
if DEBUG >= 5:
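Two consequences fall out at cleanup time: UOps.END leaves the side-effect set (an END is no longer a root that keeps things alive; it is synthesized afterwards), and the END-emission loop grows an elif for IF, closing each if statement at the end of the uop list instead of via the old if_gate bookkeeping. A condensed standalone sketch of the childless-uop sweep with the new side-effect set (simplified ops as (name, vin) tuples, not the real linearizer):

    SIDE_EFFECTS = {"STORE", "BARRIER", "DEFINE_GLOBAL"}   # END no longer in this set

    def remove_childless(uops):
        # keep side-effect uops plus anything some surviving uop still consumes
        while True:
            has_child = {v for u in uops for v in u[1]}
            kept = [u for u in uops if u[0] in SIDE_EFFECTS or u in has_child]
            if len(kept) == len(uops): return kept
            uops = kept

    const = ("CONST", ())
    dead  = ("ALU", (const,))      # consumed by nothing: swept away
    store = ("STORE", (const,))
    assert remove_childless([const, dead, store]) == [const, store]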
3 changes: 2 additions & 1 deletion tinygrad/graph.py
@@ -110,9 +110,10 @@ def print_tree(lazydata:LazyOp): print("\n".join([f"{str(i).rjust(3)} {s}" for i
def graph_uops(uops):
colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0",
UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0",
- UOps.LOOP: "#c8a0e0", UOps.PHI: "#e0ffc0"}
+ UOps.LOOP: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"}
G = nx.DiGraph()
for u in uops:
+ if u.uop == UOps.END: continue
G.add_node(u.num, label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff"))
for v in u.vin: G.add_edge(v.num, u.num)
GRAPHPATH = "/tmp/uops"
2 changes: 1 addition & 1 deletion tinygrad/renderer/cstyle.py
@@ -183,7 +183,7 @@ def ssa(u, prefix="t"):
elif uop == UOps.LOAD:
assert dtype is not None
val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL)
- if len(vin) > 2: val = lang.render_conditional(r[vin[2]], val, r[vin[3]])
+ if len(vin) > 3: val = lang.render_conditional(r[vin[2]], val, r[vin[3]])
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'val')} = {val};")
elif uop == UOps.PHI:
kk(f"{r[vin[0]]} = {r[vin[1]]};")
2 changes: 1 addition & 1 deletion tinygrad/runtime/ops_metal.py
@@ -56,7 +56,7 @@ def __init__(self, name:str, lib:bytes):
data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
self.library = unwrap(METAL.device.newLibraryWithData_error_(data, None))
self.fxn = self.library.newFunctionWithName_(name)
- if DEBUG >= 5:
+ if DEBUG >= 6:
with tempfile.NamedTemporaryFile(delete=True) as shader:
shader.write(lib)
shader.flush()