Commit
* testing new memops

* better debugging

* testing padded conv

* branching with load

* refactoring a bit

* first try

* fixing bugs

* fixing some

* eq

* eq2

* do not use x's

* working

* fixing imm

* getting things working

* refactor

* pow not working

* working except one

* refactor: one store mem

* refactor: global load

* refactor: imm

* refactor: cleaning

* fixing big offsets

* refactor with ci

* try ci

* typo

* another typo

* ubuntu default

* forgot git

* do i need git?

* missing packages

* adding python-dev

* with cache?

* buildx action

* buildx name issue?

* maybe now?

* python3

* newline warning

* maybe now

* i actually need this

* ci should work now

* improved caching

* fixing cache

* maybe now it will cache

* this

* testing cache

* trying again

* load

* missing platform

* caching gha

* testing cache

* full testing

* typo

* now?

* why

* adding checkout back

* bad formatting

* fixing convention issues

* supporting python

* adding CI flag

* testing all

* better comments

* adding debugging

* takes 12x longer

* does it output progress now?

* ignore models for speed

* fixing merge

* excluding conv_transpose2d

* only 2 tests because it's too slow

* another approach

* let's see

* faster duh

* my bad

* T_T

* typo

* sup

* with output?

* comment test

* comment test

* comment test

* :?

* no comment

* with cache

* back to normal

* testing that ci works

* back to passing

* trying again

* does it create another entry

* does it create another entry?

* build local

* hey

* Revert "excluding conv_transpose2d"

This reverts commit cc7348d.

* does it cache if done before?

* does it cache?

* done

* adding test ops

* bad formatting

* no need for this

* working static mem

* sum 1d

* add ndim

* better reg import

* fix stack

* back to np

* working except for softmax

* 5 failing

* no progress

* remove keystone

* remove keystone

* testops passing

* cleanups

* more cleanup

* typo

* ci

* ci2

* cond import

* ci3

* ci4

* ci4

* ci5

* ci5

* ci6

* alignment

* test all

* correct test

* err read_unmapped

* passing test

* ignore for speed

* ignore for speed

* ci7

* cleanup

* remove docker

* fixing merge

* fixing bugs

* add skipload for const ops

* comments

* First merge to master: Renderer

* fix emulation

* passing all tests arm64

* cleaning

* fix handcoded binary

* cleaning

* fix errs

* fix runtime arg binary

* clean git diff

* fix and clean

* fixing metal test

* cleaning

* fix metal test

* ci ~8 min

* fix pylint and clang

* cache the files in ops_clang

---------

Co-authored-by: George Hotz <[email protected]>
stevenandersonz and geohot committed Aug 15, 2023
1 parent a89142e commit 93a36c3
Showing 9 changed files with 405 additions and 152 deletions.
25 changes: 25 additions & 0 deletions .github/workflows/test.yml
@@ -267,3 +267,28 @@ jobs:
    - name: Run pytest (cuda)
      if: matrix.backend=='cuda'
      run: python -m pytest -n=auto test/ -k 'not (half or test_efficientnet_safetensors) and not (test_conv2d and test_tensor.py)' -m 'not exclude_cuda' --ignore=test/external --ignore=test/models

  testunicorn:
    name: ARM64 unicorn Test
    runs-on: ubuntu-latest
    timeout-minutes: 20
    steps:
    - name: Checkout Code
      uses: actions/checkout@v3
    - name: Set up Python 3.8
      uses: actions/setup-python@v4
      with:
        python-version: 3.8
    - name: Cache pip
      uses: actions/cache@v3
      with:
        path: '~/.cache/pip'
        key: unicorn
    - name: Install cross-assembler
      run: |
        sudo apt-get update -y && \
        sudo apt-get install -y --no-install-recommends gcc-aarch64-linux-gnu
    - name: Install dependencies
      run: pip install -e '.[testing,arm]' --extra-index-url https://download.pytorch.org/whl/cpu
    - name: Test arm
      run: CI=1 ARM64=1 CLANG=1 python -m pytest -n=auto test/ -k 'not (test_nn.py and (test_conv_transpose2d or test_conv2d))' --ignore=test/models --ignore=test/test_speed_v_torch.py --ignore=test/test_net_speed.py --ignore=test/test_specific_conv.py --ignore=test/unit/test_disk_tensor.py
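A brief, hedged note on the flags in that last step (not part of the diff): CLANG=1 selects the clang backend as the device, while CI and ARM64 are read on the Python side, roughly as sketched below; the real gates are in tinygrad/runtime/ops_clang.py further down this page.

# sketch of the gates driven by the CI job above (see ops_clang.py below for the real code)
# ARM64 = getenv('ARM64', False)   -> swap the C renderer for uops_to_arm64_asm
# if CI and ARM64: ...             -> cross-assemble for aarch64 and run the kernel under unicorn emulation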
282 changes: 144 additions & 138 deletions extra/assembly/assembly.py

Large diffs are not rendered by default.

171 changes: 171 additions & 0 deletions extra/assembly/assembly_arm64.py
@@ -0,0 +1,171 @@
import struct
from platform import system
from extra.assembly.assembly import uops_to_asmstyle, AssemblyLanguage, Register
from typing import Tuple, Set, Dict, List
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.codegen.linearizer import UOps, ConstOp, UOp
from tinygrad.helpers import dtypes, CI

def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
def compute_offsets(total):
  quotient, remainder = divmod(total, 4096)
  return [4096]*quotient + [remainder] if remainder else [4096]*quotient
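# Illustrative example (not part of the diff): compute_offsets(8292) -> [4096, 4096, 100],
# so the prologue/epilogue below adjust sp in chunks that each fit an add/sub immediate.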

#NOTE: Darwin needs names to start with a "_"
def get_name(name): return ('_' if system() == 'Darwin' else '') + name

class ARM64Language(AssemblyLanguage): pass

def specialize_to_arm64(fn_nm, asm):
  var_size = 16
  prev_uop = None
  ins = []
  x_regs = ['x' + str(i) for i in reversed(range(29)) if i not in (10,11,12,13,14,15,16,17,18,19,20)]
  s_regs = ['s' + str(i) for i in reversed(range(3,30))]
  type_to_reg = {dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
  alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
         BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
         UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
         TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcsel"}

  def mov_imm(value, reg):
    # Manually move value into reg if value can't fit
    if value.__class__ is not float and abs(value) > abs(65535):
      ins.append(f"movz w15, #{value & 0xffff}")
      ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
      ins.append(f"sxtw {reg}, w15")
    elif reg[0] == 's':
      ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}")
      ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16")
      ins.append(f"str x15, [sp, 16]")
      ins.append(f"ldr {reg}, [sp, 16]")
    else:
      ins.append(f"mov {reg}, #{value}")

  # Get variable live intervals
  live_range:Dict[str, str] = {}
  for i, (uop, out, vin, arg) in enumerate(asm):
    for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]):
      live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i]

  mem_vars:Dict[str, str] = {}
  rtor:Dict[str, str] = {}
  def allocate_regs(vars):
    nonlocal var_size
    for v in [v for v in vars if v is not None and v.__class__ is not int and v.nm not in rtor]:
      available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
      #NOTE: Very simple spill, everything that doesn't fit in regs goes to mem
      if len(available_regs) == 0:
        # ARM needs the stack 16-byte aligned
        var_size += 16
        available_regs.append('s0' if dtypes.is_float(out[1]) else 'x11')
        mem_vars[v.nm] = var_size
      rtor[v.nm] = available_regs.pop()

  temp_floats = ['s0', 's1', 's2']
  temp_ints = ['x11', 'x12', 'x13']
  for i, (uop, out, vin, arg) in enumerate(asm):
    # Free regs whose live interval has ended
    for var, reg in list(rtor.items()):
      available_regs = s_regs if reg[0] == 's' else x_regs
      if var[1] not in 'B' and var not in mem_vars and i > live_range[var][1]:
        available_regs.append(rtor.pop(var))
    # Assign registers to the variables using live ranges.
    allocate_regs([out] + vin)
    # Assign temp regs to vin and load them before direct use
    for i, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]):
      rtor[v.nm] = temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i]
      # ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
      ins.append(f"mov x15, {mem_vars[v.nm]}")
      ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")

    if uop == UOps.SPECIAL:
      if arg.startswith('data'):
        # data args 8 to n are passed on the stack
        if int(arg[4:]) >= 8:
          ins.append(f"ldr x15, [x19, #{(int(arg[4:]) - 8) * 8}]")
          ins.append(f"mov {rtor[out.nm]}, x15")
      else:
        ins.append(f"mov {rtor[out.nm]}, #0")
        ins.append(f"loop_{arg}:")
    elif uop == UOps.CAST:
      if arg == BinaryOps.CMPLT:
        mov_imm(0.0, 's0')
        mov_imm(1.0, 's1')
        ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt")
      else:
        ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
    elif uop == UOps.ALU:
      if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
      if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
        ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
      elif arg == TernaryOps.WHERE:
        ins.append(f"fcmp {rtor[vin[0].nm]}, #0.0")
        ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne")
      elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
        #NOTE: Not a real instruction, used to emulate an ext call in unicorn
        if CI: ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
        else:
          save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars]
          ins.append(f"sub sp, sp, #{(len(save_regs))*16}")
          # Save the registers before they are clobbered by the func call
          for i,k in enumerate(save_regs,1):
            ins.append(f"str {rtor[k]}, [sp, #{16*i}]")
          ins.append("stp x29, x30, [sp, #0]!")
          ins.append("mov x29, sp")
          ins.append(f"fmov s0, {rtor[vin[0].nm]}")
          ins.append(alu[arg])
          ins.append(f"fmov {rtor[out.nm]}, s0")
          ins.append("mov sp, x29")
          ins.append("ldp x29, x30, [sp], #0")
          for i,k in enumerate(save_regs,1):
            ins.append(f"ldr {rtor[k]}, [sp, #{16*i}]")
          ins.append(f"add sp, sp, #{len(save_regs)*16}")
      elif arg == BinaryOps.CMPLT:
        ins.append(f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" if not dtypes.is_float(vin[0][1]) else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}")
      elif arg == BinaryOps.MOD:
        ins.append(f"udiv x14, {rtor[vin[0].nm]}, x15")
        ins.append(f"msub {rtor[out.nm]}, x14, x15, {rtor[vin[0].nm]}")
      else:
        ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
    elif uop == UOps.LOAD:
      if arg.__class__ in (int, float):
        mov_imm(arg, rtor[out.nm])
      else:
        #NOTE: if casting is needed, load the var into the s/h0 or x/w12 temp regs
        reg_in = type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[out.nm]
        mov_imm(arg[0], "x15")
        ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
        ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
        if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] == dtypes.half else 'scvtf'} {rtor[out.nm]}, {reg_in}")
    elif uop == UOps.STORE:
      shifts = {dtypes.int64: "#3", dtypes.half: "#1", dtypes.int8:"#2", dtypes.uint8: "#2", dtypes.bool: "#2"}
      #NOTE: if casting is needed, store via the s/h0 or x/w12 temp regs
      reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
      if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] != dtypes.half else '' } {reg_out}, {rtor[vin[1].nm]}")
      ins.append(f"mov x15, #{arg[0]}")
      ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl {shifts[arg[2]] if arg[2] is not None and arg[2] in shifts else '#0'}]")
    elif uop == UOps.COND_BRANCH:
      #TODO: this is a hack, there shouldn't always need to be a cmp before a cond branch
      if prev_uop == UOps.LOAD:
        ins.append(f"cmp {rtor[vin[0].nm]}, #0")
      ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}")
    elif uop == UOps.LABEL:
      ins.append(f"{arg[1:]}:")
    elif uop == UOps.ENDLOOP:
      mov_imm(arg[0], "x15")
      ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
      ins.append(f"cmp {rtor[vin[0].nm]}, x15")
      ins.append(f"b.lt loop_{arg[1]}")

    prev_uop=uop
    # store regs into memory if needed
    if out is not None and out.nm in mem_vars:
      ins.append(f"mov x15, {mem_vars[out.nm]}")
      ins.append(f"str {rtor[out.nm]}, [sp, x15]")
  return "\n".join([f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x19, sp"] + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]+ ins + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] +["ret", "\n"])

def uops_to_arm64_asm(fn_nm:str, uops:List[UOp]) -> Tuple[str, List[int], List[int], bool]:
  lang = ARM64Language()
  global_size, local_size = uops_to_asmstyle(lang, fn_nm, uops)
  return specialize_to_arm64(fn_nm, lang.ins), global_size[::-1], local_size[::-1], True
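For orientation, a minimal usage sketch (not part of the diff; `k` stands for some already-linearized kernel) showing what the Compiled backend receives from this renderer:

# hypothetical sketch of the call made from tinygrad/ops.py below
from extra.assembly.assembly_arm64 import uops_to_arm64_asm
src, global_size, local_size, is_binary = uops_to_arm64_asm(k.function_name, k.uops)
# src is ARM64 assembly text (it starts with "//varsize N" and ends with "ret");
# is_binary=True tells the clang runtime to assemble it rather than compile it as C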
1 change: 1 addition & 0 deletions setup.py
@@ -24,6 +24,7 @@
  extras_require={
    'llvm': ["llvmlite"],
    'cuda': ["pycuda"],
    'arm': ["unicorn"],
    'triton': ["triton>=2.0.0.dev20221202"],
    'webgpu': ["wgpu"],
    'metal': ["pyobjc-framework-Metal", "pyobjc-framework-Cocoa", "pyobjc-framework-libdispatch"],
2 changes: 1 addition & 1 deletion test/external/external_test_speed_llama.py
@@ -11,7 +11,7 @@
from tinygrad.runtime.lib import RawBuffer

class FakeProgram:
  def __init__(self, name:str, prg:str): pass
  def __init__(self, name:str, prg:str, binary:bool): pass
  def __call__(self, global_size, local_size, *bufs, wait=False): pass

class RawFakeBuffer(RawBuffer):
7 changes: 4 additions & 3 deletions test/test_uops.py
@@ -1,14 +1,15 @@
import unittest, math
import numpy as np
from tinygrad.helpers import dtypes
from tinygrad.helpers import dtypes, getenv
from tinygrad.tensor import Device
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ASTRunner, Compiled
from tinygrad.codegen.linearizer import UOps, Token, ConstOp, MemOp
from tinygrad.shape.symbolic import Variable

def _uops_to_prg(uops):
  src, global_size, local_size = Device[Device.DEFAULT].renderer("test", uops)
  return ASTRunner("test", src, global_size, local_size).build(Device[Device.DEFAULT].runtime)
  ret = Device[Device.DEFAULT].renderer("test", uops)
  src, global_size, local_size, binary = ret if len(ret) == 4 else ret + (False,)
  return ASTRunner("test", src, global_size, local_size, runtime_args={"binary": binary}).build(Device[Device.DEFAULT].runtime)

def _test_single_value(tc, tt, vals, op):
  uops = [
6 changes: 4 additions & 2 deletions tinygrad/ops.py
@@ -156,10 +156,12 @@ def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, runtime,

  def to_program(self, k):
    k.linearize()
    src, global_size, local_size = self.renderer(k.function_name, k.uops)
    ret = self.renderer(k.function_name, k.uops)
    src, global_size, local_size, binary = ret if len(ret) == 4 else ret + (False,)
    #TODO: I need to find a better way to select ARM64
    return ASTRunner(k.function_name, src, global_size, local_size,
                     op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
                     display_name=k.display_name).build(self.runtime)
                     display_name=k.display_name, runtime_args={"binary": binary}).build(self.runtime)

  def exec_ast(self, ast:LazyOp, output, **kwargs):
    # all movementops do nothing in a Compiled buffer!
61 changes: 54 additions & 7 deletions tinygrad/runtime/ops_clang.py
@@ -1,8 +1,15 @@
import os, time, ctypes, hashlib, subprocess, platform, tempfile, functools
from functools import partial, reduce
from tinygrad.ops import Compiled
from tinygrad.helpers import fromimport, getenv, DEBUG, CI
from tinygrad.runtime.lib import RawMallocBuffer
from tinygrad.codegen.linearizer import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
import struct
import numpy as np

ARM64 = getenv('ARM64', False)
if CI and ARM64: from unicorn import Uc, UC_ARCH_ARM64, UC_MODE_ARM, UC_HOOK_CODE, arm64_const

args = {
  'Windows': {'cflags':'', 'ext':'dll', 'exp':'__declspec(dllexport)'},
@@ -11,24 +18,64 @@
}[platform.system()]

CLANG_PROGRAM_HEADER = '#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#define bool uchar\n'
ADDRESS = 0x10000

# Unicorn doesn't support external calls
def align(addr): return (addr+4095) & ~(4095)
mock_lm = {"sinf": np.sin, "sqrtf": np.sqrt, "exp2f": np.exp2, "log2f": np.log2}
def emulate_ext_calls(fn, uc, address, size, user_data):
  s_in = struct.unpack('f', struct.pack('I', uc.reg_read(getattr(arm64_const, f'UC_ARM64_REG_S{fn[2][1:]}'))))[0]
  uc.reg_write(getattr(arm64_const, f'UC_ARM64_REG_S{fn[1][1:]}'), struct.unpack('I', struct.pack('f', mock_lm[fn[0]](s_in)))[0]) # type: ignore
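# Illustrative example (not part of the diff): a rendered line such as "bl sinf s1 s0" is
# turned into a nop in the assembled binary, and ext_calls maps its address to
# ['sinf', 's1', 's0']; when emulation reaches that address the hook reads s0, applies
# mock_lm['sinf'] (np.sin), and writes the float result back into s1.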

class ClangProgram:
  def __init__(self, name:str, prg:str):
    prg = CLANG_PROGRAM_HEADER + prg
  def __init__(self, name:str, prg:str, binary:bool=False):
    # TODO: is there a way to not write this to disk?
    # A: it seems there isn't https://stackoverflow.com/questions/28053328/ctypes-cdll-load-library-from-memory-rather-than-file
    #    because ctypes.CDLL() calls dlopen (POSIX) or LoadLibrary (Windows) which require a file
    fn = f"{tempfile.gettempdir()}/clang_{hashlib.md5(prg.encode('utf-8')).hexdigest()}.{args['ext']}"
    if not os.path.exists(fn):
      tmp = f"{fn}.{os.getpid()}.tmp"
      subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+tmp).split(), input=prg.encode('utf-8'))
      os.rename(tmp, fn)
      if not binary:
        prg = CLANG_PROGRAM_HEADER + prg
        subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+tmp).split(), input=prg.encode('utf-8'))
        os.rename(tmp, fn)
      else:
        if DEBUG >= 5: print(prg)
        if CI and ARM64:
          prg = prg.split('\n') # type: ignore
          self.varsize = align(int(prg[0].split(" ")[1]))
          self.ext_calls = {(i*4+ADDRESS):ins.split(" ")[1:] for i, ins in enumerate(filter(lambda ins: ins[:4] != 'loop', prg[6:-3])) if ins[:2] == 'bl'}
          prg = "\n".join(['nop' if ins[:2] == 'bl' else ins for ins in prg[6:-3]] + ['\n'])
          subprocess.check_output(args=('aarch64-linux-gnu-as -o '+tmp).split(), input=prg.encode('utf-8'))
          subprocess.check_output(args=('aarch64-linux-gnu-objcopy -O binary --only-section=.text '+tmp+' '+fn+'.bin').split())
          self.prg = open(fn + '.bin', 'rb').read()
          return
        subprocess.check_output(args=('as -o' + tmp).split(), input=prg.encode('utf-8'))
        subprocess.check_output(args=('clang -lm -shared '+tmp+' -o'+fn).split())
    self.lib = ctypes.CDLL(fn)
    self.fxn = self.lib[name]

  def __call__(self, global_size, local_size, *args, wait=False):
    if wait: st = time.monotonic()
    self.fxn(*[x._buf for x in args])
    if CI and ARM64:
      mu = Uc(UC_ARCH_ARM64, UC_MODE_ARM)
      total_mem = align(reduce(lambda total, arg: total + arg.size * arg.dtype.itemsize, args, len(self.prg)+self.varsize))
      mu.mem_map(ADDRESS, total_mem)
      for k, fn in self.ext_calls.items(): mu.hook_add(UC_HOOK_CODE, partial(emulate_ext_calls, fn), begin=k, end=k)
      mu.mem_write(ADDRESS, self.prg + b''.join(bytes(arg._buf) for arg in args))
      addr = ADDRESS + len(self.prg)
      for i, arg in enumerate(args):
        if i<=7:
          mu.reg_write(getattr(arm64_const, f'UC_ARM64_REG_X{i}'), addr)
        else:
          # NOTE: In ARM, args beyond the first 8 are placed on the stack; this also accounts for the stack red zone.
          mu.mem_write(ADDRESS + total_mem - (len(args[8:])+2)*8 + 8*(i-8), addr.to_bytes(8, 'little'))
        addr += arg.size * arg.dtype.itemsize
      mu.reg_write(arm64_const.UC_ARM64_REG_SP, ADDRESS + total_mem - (len(args[8:])+2)*8)
      mu.emu_start(ADDRESS, ADDRESS + len(self.prg))
      args[0]._buf = mu.mem_read(mu.reg_read(arm64_const.UC_ARM64_REG_X0), args[0].size * args[0].dtype.itemsize)
    else:
      self.fxn(*[x._buf for x in args])
    if wait: return time.monotonic()-st
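    # Descriptive note (not part of the diff): the emulated mapping places the code at ADDRESS,
    # appends every raw buffer right after it, and reserves the top of the mapping for the
    # stack-passed args (args[8:]) plus padding, which is where sp points before emu_start.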

renderer = functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict"))
renderer = fromimport("extra.assembly.assembly_arm64", "uops_to_arm64_asm") if ARM64 else functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict"))
ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, ClangProgram)
2 changes: 1 addition & 1 deletion tinygrad/runtime/ops_metal.py
@@ -37,7 +37,7 @@ def unwrap(x):
  return ret

class MetalProgram:
  def __init__(self, name:str, prg:str):
  def __init__(self, name:str, prg:str, binary:bool=False):
    if METAL_XCODE:
      air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
      # NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
