bring hip graph back (tinygrad#2385)
* bring hip graph back

* share with metal

* fix linter

* remove hasattrs

* Update ops_hip.py

* hip wrapper does not use _buf

---------

Co-authored-by: George Hotz <[email protected]>
nimlgen and geohot committed Nov 24, 2023
1 parent 46b05da commit e68aebf
Showing 4 changed files with 104 additions and 63 deletions.
50 changes: 16 additions & 34 deletions extra/hip_wrapper.py
@@ -177,48 +177,30 @@ class kernelNodeParamsWrapper():
   c_struct: Any
   context: Any = None
 
-# Better to cache struct_types since they reused often and take a lot of time to create.
-struct_type_cache: Dict[str, Any] = {}
-def __get_struct(name, field_list):
-  global struct_type_cache
-  if name in struct_type_cache:
-    return struct_type_cache[name]
+def getCStructForType(argtypes):
+  fields = []
+  for j,typ in enumerate(argtypes):
+    fields.append((f'field{j}', typ))
 
   class CStructure(ctypes.Structure):
-    _fields_ = field_list
-  struct_type_cache[name] = CStructure
-  return struct_type_cache[name]
-
-def getStructTypeForArgs(*args):
-  types = ""
-  fields: List[Tuple[str, Any]] = []
-  for idx in range(len(args)):
-    if args[idx].__class__ is int:
-      types += 'i'
-      fields.append((f'field{idx}', ctypes.c_int))
-    else:
-      types += 'P'
-      fields.append((f'field{idx}', ctypes.c_void_p))
-  return __get_struct(types, fields)
-
-def updateKernelNodeParams(npwrapper:kernelNodeParamsWrapper, *args, grid=(1,1,1), block=(1,1,1), updated_args=None):
-  _, struct, _ = npwrapper.context
-  if updated_args is not None:
-    for i in updated_args:
-      setattr(struct, f'field{i}', (args[i] if args[i].__class__ is int else args[i]._buf))
-  else:
-    for i,d in enumerate(args):
-      setattr(struct, f'field{i}', (d if d.__class__ is int else d._buf))
+    _fields_ = fields
+  return CStructure
 
 def setKernelNodeLaunchDims(npwrapper:kernelNodeParamsWrapper, grid, block):
   npwrapper.c_struct.blockDimX = block[0]
   npwrapper.c_struct.blockDimY = block[1]
   npwrapper.c_struct.blockDimZ = block[2]
   npwrapper.c_struct.gridDimX = grid[0]
   npwrapper.c_struct.gridDimY = grid[1]
   npwrapper.c_struct.gridDimZ = grid[2]
 
-def buildKernelNodeParams(*args, func=None, grid=(1,1,1), block=(1,1,1), sharedMemBytes=0, argsStructType=None):
-  data = [d if d.__class__ is int else d._buf for d in args]
-  if argsStructType is None: argsStructType = getStructTypeForArgs(*args)
-  struct = argsStructType(*data)
+def setKernelNodeParams(npwrapper:kernelNodeParamsWrapper, args, ids):
+  for j,i in enumerate(ids):
+    setattr(npwrapper.context[1], f'field{i}', args[j])
+
+def buildKernelNodeParams(args, argtypes, func, grid, block, sharedMemBytes=0):
+  c_struct_t = getCStructForType(argtypes)
+  struct = c_struct_t(*args)
   size = ctypes.c_size_t(ctypes.sizeof(struct))
   p_size = ctypes.c_void_p(ctypes.addressof(size))
   p_struct = ctypes.c_void_p(ctypes.addressof(struct))
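Annotation (not part of the commit): a minimal sketch of how the reworked helpers are meant to chain together. The buffers a and b, the kernel handle, and the sizes are hypothetical stand-ins; only the hip.* calls shown in this commit are used.

    import ctypes
    import extra.hip_wrapper as hip

    graph = hip.hipGraphCreate()
    args     = [a._buf, b._buf, 16]                              # pointer args first, then plain ints (hypothetical values)
    argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]  # one ctypes type per argument
    params = hip.buildKernelNodeParams(args, argtypes, kernel, (1,1,1), (256,1,1))
    node = hip.hipGraphAddKernelNode(graph, [], params)          # no dependencies for the first node
    # later: patch only argument index 2 (the int) inside the cached struct...
    hip.setKernelNodeParams(params, [32], [2])
    # ...then push the updated struct into the instantiated graph, as HIPGraph in ops_hip.py does below.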
8 changes: 6 additions & 2 deletions tinygrad/jit.py
@@ -1,8 +1,8 @@
 from __future__ import annotations
 from typing import Callable, List, Tuple, Dict, cast, Union, Optional, TypeVar, Generic
 import functools, itertools, operator
-from tinygrad.helpers import DEBUG, DType, merge_dicts, getenv
-from tinygrad.ops import RawBuffer, Device, JITRunner
+from tinygrad.helpers import DEBUG, DType, merge_dicts, getenv, all_int
+from tinygrad.ops import RawBuffer, Device, JITRunner, CompiledASTRunner
 from tinygrad.tensor import Tensor
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import Variable, NumNode, Node
@@ -24,6 +24,10 @@ def get_input_replace(jit_cache: List[JitItem], input_rawbuffers:List[RawBuffer]
         input_replace[(j,i)] = input_rawbuffers.index(a)
   assert len(set(input_replace.values())) == len(input_rawbuffers), "some input tensors not found"
   return input_replace
+def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[JitItem]) -> List[int]:
+  return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ((ji.prg.global_size and not all_int(tuple(ji.prg.global_size))) or (ji.prg.local_size and not all_int(tuple(ji.prg.local_size))))]
+def get_jc_idxs_with_updatable_var_vals(jit_cache: List[JitItem]) -> List[int]:
+  return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ji.prg.vars]
 
 class GraphException(Exception): pass
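Annotation (not in the diff): both new helpers key off whether a runner's launch dimensions or shape variables are still symbolic at capture time, and all_int is what separates the two cases, e.g.:

    from tinygrad.shape.symbolic import Variable
    from tinygrad.helpers import all_int

    i = Variable("i", 1, 10)    # a dimension bound only at call time
    print(all_int((i, 1, 1)))   # False -> launch dims must be patched on every call
    print(all_int((4, 1, 1)))   # True  -> launch dims can be baked in at capture time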

92 changes: 75 additions & 17 deletions tinygrad/runtime/ops_hip.py
@@ -1,12 +1,14 @@
 import numpy as np
 import ctypes
 import extra.hip_wrapper as hip
-from typing import Tuple
+from typing import Tuple, List, Any, Dict, cast, Optional, Callable
 from tinygrad.helpers import DEBUG, getenv, diskcache
-from tinygrad.ops import Compiled
+from tinygrad.ops import Compiled, CompiledASTRunner, update_stats
 from tinygrad.renderer.hip import HIPRenderer
-from tinygrad.runtime.lib import RawBuffer, RawBufferCopyInOut, LRUAllocator, RawBufferTransfer
+from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer, RawBuffer
 from tinygrad.codegen.kernel import LinearizerOptions
+from tinygrad.shape.symbolic import Variable
+from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals, GraphException
 
 # TODO: if you fork and exit the child process after creating anything with cl on AMD, it hangs on e.wait()
 if DEBUG >= 6:
@@ -48,9 +50,22 @@ def compile_hip(prg) -> bytes:
   hip.hiprtcCompileProgram(prog, [f'--offload-arch={hip.hipGetDeviceProperties(HIP.default_device).gcnArchName}'])
   return hip.hiprtcGetCode(prog)
 
+def time_execution(cb, enable=False):
+  if enable:
+    start, end = hip.hipEventCreate(), hip.hipEventCreate()
+    hip.hipEventRecord(start)
+  cb()
+  if enable:
+    hip.hipEventRecord(end)
+    hip.hipEventSynchronize(end)
+    ret = hip.hipEventElapsedTime(start, end)*1e-3
+    hip.hipEventDestroy(start)
+    hip.hipEventDestroy(end)
+    return ret
+
 class HIPProgram:
   def __init__(self, name:str, prg:bytes):
-    self.modules, self.prgs = [], []
+    self.modules, self.prgs, self.c_struct_t = [], [], None
 
     if DEBUG >= 6:
       asm = early_exec((["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], prg))
@@ -63,20 +78,63 @@ def __init__(self, name:str, prg:bytes):
 
   def __call__(self, *args, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], wait=False):
     hip.hipSetDevice(args[0]._device)
-    if wait:
-      start, end = hip.hipEventCreate(), hip.hipEventCreate()
-      hip.hipEventRecord(start)
-    struct = hip.getStructTypeForArgs(*args)(*[data._buf if not isinstance(data, int) else np.int32(data) for data in args])
-    hip.hipModuleLaunchKernel(self.prgs[args[0]._device], global_size[0], global_size[1], global_size[2], local_size[0], local_size[1], local_size[2], 0, 0, struct)
-    if wait:
-      hip.hipEventRecord(end)
-      hip.hipEventSynchronize(end)
-      ret = hip.hipEventElapsedTime(start, end)*1e-3
-      hip.hipEventDestroy(start)
-      hip.hipEventDestroy(end)
-      return ret
+    if self.c_struct_t is None: self.c_struct_t = hip.getCStructForType([(ctypes.c_void_p if not isinstance(x, int) else ctypes.c_int) for x in args])
+    c_params = cast(Callable, self.c_struct_t)(*[x._buf if not isinstance(x, int) else x for x in args])
+    return time_execution(lambda: hip.hipModuleLaunchKernel(self.prgs[args[0]._device], *global_size, *local_size, 0, 0, c_params), enable=wait)
 
   def __del__(self):
     for module in self.modules: hip.hipModuleUnload(module)
 
-HIPBuffer = Compiled(RawHIPBuffer, LinearizerOptions(device="HIP"), HIPRenderer, compile_hip, HIPProgram, hip.hipDeviceSynchronize)
+class HIPGraph:
+  def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int]):
+    # TODO: Only HIPProgram can be captured for now.
+    if not all(isinstance(ji.prg, CompiledASTRunner) and isinstance(ji.prg.clprg, HIPProgram) for ji in jit_cache): raise GraphException
+
+    self.jit_cache = jit_cache
+    self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
+    self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache)
+    self.jc_idxs_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
+    self.jc_idxs_with_updatable_var_vals = get_jc_idxs_with_updatable_var_vals(jit_cache)
+    self.jc_idxs_with_updatable_rawbufs = list(set([x[0] for x in self.input_replace.keys()]))
+
+    self.graph, graph_node = hip.hipGraphCreate(), None
+    self.updatable_nodes: Dict[int, Tuple[Any, hip.kernelNodeParamsWrapper]] = {} # Dict[jc index] = tuple(graph_node, node_params)
+
+    for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
+    for j,ji in enumerate(self.jit_cache):
+      prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)
+      assert all(x is not None for x in ji.rawbufs) and ji.rawbufs[0] is not None, "buffers could not be None" # for linters
+
+      args = [cast(RawBuffer, x)._buf for x in ji.rawbufs] + [var_vals[x] for x in prg.vars]
+      types = [ctypes.c_void_p] * len(ji.rawbufs) + [ctypes.c_int] * len(prg.vars)
+      c_params = hip.buildKernelNodeParams(args, types, prg.clprg.prgs[ji.rawbufs[0]._device], *prg.launch_dims(var_vals))
+      graph_node = hip.hipGraphAddKernelNode(self.graph, [graph_node] if graph_node else [], c_params)
+
+      if j in self.jc_idxs_with_updatable_launch_dims or j in self.jc_idxs_with_updatable_var_vals or j in self.jc_idxs_with_updatable_rawbufs:
+        self.updatable_nodes[j] = (graph_node, c_params)
+
+    self.instance = hip.hipGraphInstantiate(self.graph)
+
+  def __del__(self):
+    hip.hipGraphExecDestroy(self.instance)
+    hip.hipGraphDestroy(self.graph)
+
+  def __call__(self, input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
+    # Update cached params structs with the new values.
+    for (j,i),input_idx in self.input_replace.items():
+      hip.setKernelNodeParams(self.updatable_nodes[j][1], [input_rawbuffers[input_idx]._buf], [i])
+    for j in self.jc_idxs_with_updatable_launch_dims:
+      hip.setKernelNodeLaunchDims(self.updatable_nodes[j][1], *cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals))
+    for j in self.jc_idxs_with_updatable_var_vals:
+      prg: CompiledASTRunner = cast(CompiledASTRunner, self.jit_cache[j].prg)
+      hip.setKernelNodeParams(self.updatable_nodes[j][1], [var_vals[x] for x in prg.vars], list(range(len(self.jit_cache[j].rawbufs), len(self.jit_cache[j].rawbufs) + len(prg.vars))))
+
+    # Update graph nodes with the updated structs.
+    for node, params in self.updatable_nodes.values():
+      hip.hipGraphExecKernelNodeSetParams(self.instance, node, params)
+
+    et = time_execution(lambda: hip.hipGraphLaunch(self.instance), enable=wait)
+    update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache))
+    return et
+
+HIPBuffer = Compiled(RawHIPBuffer, LinearizerOptions(device="HIP"), HIPRenderer, compile_hip, HIPProgram, hip.hipDeviceSynchronize, graph=HIPGraph)
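Annotation (not in the diff): roughly how the JIT is expected to drive this class, assuming jit_cache, input_rawbuffers and var_vals were captured on an earlier run:

    try:
      graph = HIPGraph(jit_cache, input_rawbuffers, var_vals)       # capture: one hipGraphAddKernelNode per cached kernel
    except GraphException:
      graph = None                                                  # e.g. a non-HIPProgram runner in the cache; replay item by item instead
    if graph is not None:
      et = graph(input_rawbuffers, var_vals, wait=False, jit=True)  # patch updatable nodes, then a single hipGraphLaunch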
17 changes: 7 additions & 10 deletions tinygrad/runtime/ops_metal.py
@@ -1,13 +1,13 @@
 import os, subprocess, pathlib, ctypes, tempfile
 import Metal, libdispatch
-from typing import List, Any, Tuple, Dict, Set, cast, Optional
+from typing import List, Any, Tuple, Dict, cast, Optional
 from tinygrad.codegen.kernel import LinearizerOptions
 from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache, dedup
 from tinygrad.ops import Compiled, CompiledASTRunner, update_stats
 from tinygrad.renderer.metal import MetalRenderer
 from tinygrad.runtime.lib import RawBufferMapped, RawBuffer, LRUAllocator
-from tinygrad.shape.symbolic import Variable, Node
-from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, GraphException
+from tinygrad.shape.symbolic import Variable
+from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, GraphException
 
 class MetalAllocator(LRUAllocator):
   def _do_alloc(self, size, dtype, device, **kwargs):
@@ -88,6 +88,7 @@ def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[RawBuffer],
     self.jit_cache = jit_cache
     self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
     self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache)
+    self.jc_idx_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
 
     # create metal batch exec
     icb_descriptor = Metal.MTLIndirectCommandBufferDescriptor.new()
@@ -99,7 +100,6 @@ def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[RawBuffer],
     if self.icb is None: raise GraphException("create indirect command buffer failed, does your system support this?")
 
     self.int_buf = RawMetalBuffer(len(var_vals), dtypes.int32)
-    self.input_has_variable_dims: Set[int] = set()
     read_resources, write_resources = [], []
     for j,ji in enumerate(self.jit_cache):
       prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)
@@ -117,11 +117,8 @@ def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[RawBuffer],
       var_vals_keys = list(var_vals.keys())
       for i,v in enumerate(prg.vars):
         icb_command.setKernelBuffer_offset_atIndex_(self.int_buf._buf, var_vals_keys.index(v)*4, len(ji.rawbufs)+i)
-      global_size, local_size = prg.launch_dims(var_vals)
-      assert prg.global_size and prg.local_size, "need global and local size to JIT"
-      if any(isinstance(x, Node) for x in prg.global_size) or any(isinstance(x, Node) for x in prg.local_size):
-        self.input_has_variable_dims.add(j)
-      else:
+      if j not in self.jc_idx_with_updatable_launch_dims:
+        global_size, local_size = prg.launch_dims(var_vals)
         icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
       icb_command.setBarrier()
     self.read_resources, self.write_resources = dedup(read_resources), dedup(write_resources)
Expand All @@ -134,7 +131,7 @@ def __call__(self, input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, i
all_read_resources = self.read_resources + [x._buf for x in input_rawbuffers]
for (j,i),input_idx in self.input_replace.items():
self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_idx]._buf, 0, i)
for j in self.input_has_variable_dims:
for j in self.jc_idx_with_updatable_launch_dims:
global_size, local_size = cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals)
self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
self.int_buf_view[:] = list(var_vals.values())
