forked from tinygrad/tinygrad
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* testing new memops * better debugging * testing padded conv * branching with load * refactoring a bit * first try * fixing bugs * fixing some * eq * eq2 * do not use x's * working * fixing imm * getting things working * refactor * pow not working * working except one * refactor: one store mem * refactor: global load * refactor: imm * refactor: cleaning * fixing big offsets * refactor with ci * try ci * typo * another typo * ubuntu default * forgot git * do i need git? * missing packages * adding python-dev * with cache? * buildx action * buildx name issue? * maybe now? * python3 * newline warning * maybe now * i actually need this * ci should work now * improved caching * fixing cache * maybe now it will cache * this * testing cache * trying again * load * missing platform * caching gha * testing cache * full testing * typo * now? * why * adding checkout back * bad formatting * fixing convention issues * supporting python * adding CI flag * testing all * better comments * adding debugging * takes 12x longer * does it output progress now? * ignore models for speed * fixing merge * excluding conv_transpose2d * only 2 test cuz is to slow * another approach * let's see * faster duh * my bad * T_T * typo * sup * with output? * comment test * comment test * comment test * :? * no comment * with cache * back to normal * testing that ci works * back to passing * trying again * does it create another entry * does it create another entry? * build local * hey * Revert "excluding conv_transpose2d" This reverts commit cc7348d. * does it cache if done before? * does it cache? 
* done * adding test ops * bad formatting * no need for this * working static mem * sum 1d * add ndim * better reg import * fix stack * back to np * working except for softmax * 5 failing * no pogress * remove keystone * remove keystone * testops passing * cleanups * more cleanup * typo * ci * ci2 * cond import * ci3 * ci4 * ci4 * ci5 * ci5 * ci6 * aligment * test all * correct test * err read_unmapped * passing test * ignore for speed * ignore for speed * ci7 * cleanup * remove docker * fixing merge * fixing bugs * add skipload for const ops * comments * First merge to master: Renderer * fix emulation * passing all tests arm64 * cleaning * fix handcoded binary * cleaning * fix errs * fix runtime arg binary * clean git diff * fix and clean * fixing metal test * cleaning * fix metal test * ci ~8 min * fix pylint and clang * cache the files in ops_clang --------- Co-authored-by: George Hotz <[email protected]>
- Loading branch information
1 parent
a89142e
commit 93a36c3
Showing
9 changed files
with
405 additions
and
152 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
import struct | ||
from platform import system | ||
from extra.assembly.assembly import uops_to_asmstyle, AssemblyLanguage, Register | ||
from typing import Tuple, Set, Dict, List | ||
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps | ||
from tinygrad.codegen.linearizer import UOps, ConstOp, UOp | ||
from tinygrad.helpers import dtypes, CI | ||
|
||
def float_to_hex(x):
  """Return the IEEE-754 single-precision encoding of x as 8 uppercase hex digits.

  The most significant byte comes first, so the first four digits are the
  upper 16 bits of the encoding and the last four digits are the lower 16 bits.
  """
  # Pack big-endian explicitly: packing in native order and reversing the bytes
  # only yields this digit order on little-endian hosts.
  return struct.pack(">f", x).hex().upper()
def compute_offsets(total):
  """Break total into a list of 4096-sized chunks followed by any non-zero remainder.

  A total that is an exact multiple of 4096 produces no trailing remainder
  entry, and a total of 0 produces an empty list.
  """
  full_chunks, leftover = divmod(total, 4096)
  offsets = [4096] * full_chunks
  if leftover:
    offsets.append(leftover)
  return offsets
|
||
#NOTE: Darwin needs names to start with a "_"
def get_name(name):
  """Return name with a leading underscore when running on Darwin, unchanged otherwise."""
  prefix = '_' if system() == 'Darwin' else ''
  return prefix + name
|
||
class ARM64Language(AssemblyLanguage):
  """AssemblyLanguage subclass used by the ARM64 backend; it adds no members of its own."""
|
||
def specialize_to_arm64(fn_nm, asm):
  """Lower a list of asm-style uops into ARM64 assembly source text.

  Args:
    fn_nm: name of the generated function; it is passed through get_name to
      form the exported symbol.
    asm: sequence of (uop, out, vin, arg) tuples. out is a register-like
      object or None; vin is a list whose entries are register-like objects
      or ints. Register-like objects expose .nm and a dtype at index 1.

  Returns:
    The whole program as one newline-joined string. Its first line is a
    "//varsize N" comment, where N is the number of stack bytes reserved.
  """
  var_size = 16
  prev_uop = None
  ins = []
  # x10-x20 are kept out of the allocatable pool; several of them are used below as fixed
  # registers: x11-x13 (spill temporaries), x14/x15 (scratch) and x19 (entry stack pointer).
  x_regs = ['x' + str(i) for i in reversed(range(29)) if i not in (10,11,12,13,14,15,16,17,18,19,20)]
  # s0-s2 are left out because they serve as the float temporaries.
  s_regs = ['s' + str(i) for i in reversed(range(3,30))]
  type_to_reg = {dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
  alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
         BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
         UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
         TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcsel"}

  def mov_imm(value, reg):
    """Emit the instructions that place the immediate value into reg."""
    # Manually move value into reg if value can't fit
    if value.__class__ is not float and abs(value) > abs(65535):
      # Build the value 16 bits at a time in w15, then sign-extend into the target.
      ins.append(f"movz w15, #{value & 0xffff}")
      ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
      ins.append(f"sxtw {reg}, w15")
    elif reg[0] == 's':
      # Float target: build the bit pattern in x15 and bounce it through the stack slot at sp+16.
      ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}")
      ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16")
      ins.append("str x15, [sp, 16]")
      ins.append(f"ldr {reg}, [sp, 16]")
    else:
      ins.append(f"mov {reg}, #{value}")

  # Get variables intervals
  live_range:Dict[str, List[int]] = {}
  for i, (uop, out, vin, arg) in enumerate(asm):
    for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]):
      live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i]

  mem_vars:Dict[str, int] = {}
  rtor:Dict[str, str] = {}
  def allocate_regs(mvars):
    """Give every not-yet-mapped variable in mvars a register, spilling to the stack when none is free."""
    nonlocal var_size
    for v in [v for v in mvars if v is not None and v.__class__ is not int and v.nm not in rtor]:
      available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
      #NOTE: Very simple spill, everything that don't fit in regs goes to mem
      if not available_regs:
        # ARM needs the stack 16-byte aligned
        var_size += 16
        # The placeholder must match the class of the variable being allocated (v),
        # not the `out` of whichever uop happens to be processed (which may be None).
        available_regs.append('s0' if dtypes.is_float(v[1]) else 'x11')
        mem_vars[v.nm] = var_size
      rtor[v.nm] = available_regs.pop()

  temp_floats = ['s0', 's1', 's2']
  temp_ints = ['x11', 'x12', 'x13']
  for i, (uop, out, vin, arg) in enumerate(asm):
    # Clear regs out of interval
    for var, reg in list(rtor.items()):
      available_regs = s_regs if reg[0] == 's' else x_regs
      if var[1] not in 'B' and var not in mem_vars and i > live_range[var][1]:
        available_regs.append(rtor.pop(var))
    # Assign a registers to the variables using live ranges.
    allocate_regs([out] + vin)
    # Assign temp regs to vin and load them before direct use
    for slot, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]):
      rtor[v.nm] = temp_floats[slot] if dtypes.is_float(v[1]) else temp_ints[slot]
      # ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
      ins.append(f"mov x15, {mem_vars[v.nm]}")
      ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")

    if uop == UOps.SPECIAL:
      if arg.startswith('data'):
        # data 8 to n into the stack
        if int(arg[4:]) >= 8:
          ins.append(f"ldr x15, [x19, #{(int(arg[4:]) - 8) * 8}]")
          ins.append(f"mov {rtor[out.nm]}, x15")
      else:
        # Any other special initialises its register to 0 and opens a loop label.
        ins.append(f"mov {rtor[out.nm]}, #0")
        ins.append(f"loop_{arg}:")
    elif uop == UOps.CAST:
      if arg == BinaryOps.CMPLT:
        # Turn the "lt" condition flag into 1.0 / 0.0.
        mov_imm(0.0, 's0')
        mov_imm(1.0, 's1')
        ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt")
      else:
        ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
    elif uop == UOps.ALU:
      # An int second operand is materialised in x15 and referenced as 'x15' below.
      if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
      if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
        ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
      elif arg == TernaryOps.WHERE:
        ins.append(f"fcmp {rtor[vin[0].nm]}, #0.0")
        ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne")
      elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
        #NOTE: Not a real instruction, use to emulate a ext call in unicorn
        if CI:
          ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
        else:
          save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars]
          ins.append(f"sub sp, sp, #{(len(save_regs))*16}")
          # Save the registers before they are cleared by func call
          for n,k in enumerate(save_regs,1):
            ins.append(f"str {rtor[k]}, [sp, #{16*n}]")
          ins.append("stp x29, x30, [sp, #0]!")
          ins.append("mov x29, sp")
          # The argument is passed and the result returned in s0.
          ins.append(f"fmov s0, {rtor[vin[0].nm]}")
          ins.append(alu[arg])
          ins.append(f"fmov {rtor[out.nm]}, s0")
          ins.append("mov sp, x29")
          ins.append("ldp x29, x30, [sp], #0")
          for n,k in enumerate(save_regs,1):
            ins.append(f"ldr {rtor[k]}, [sp, #{16*n}]")
          ins.append(f"add sp, sp, #{len(save_regs)*16}")
      elif arg == BinaryOps.CMPLT:
        if dtypes.is_float(vin[0][1]):
          ins.append(f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}")
        else:
          ins.append(f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
      elif arg == BinaryOps.MOD:
        # out = vin[0] - (vin[0] / x15) * x15, with the quotient held in x14.
        ins.append(f"udiv x14, {rtor[vin[0].nm]}, x15")
        ins.append(f"msub {rtor[out.nm]}, x14, x15, {rtor[vin[0].nm]}")
      else:
        # Mnemonic prefix: 'f' for float operands, 's' for integer DIV, none otherwise.
        ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
    elif uop == UOps.LOAD:
      if arg.__class__ in (int, float):
        mov_imm(arg, rtor[out.nm])
      else:
        #NOTE: if need casting load var in s/h0 or x/w12 temp regs
        reg_in = type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[out.nm]
        mov_imm(arg[0], "x15")
        ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
        ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
        if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] == dtypes.half else 'scvtf'} {rtor[out.nm]}, {reg_in}")
    elif uop == UOps.STORE:
      shifts = {dtypes.int64: "#3", dtypes.half: "#1", dtypes.int8:"#2", dtypes.uint8: "#2", dtypes.bool: "#2"}
      #NOTE: if need casting load var in s/h0 or x/w12 temp regs
      reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
      if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] != dtypes.half else '' } {reg_out}, {rtor[vin[1].nm]}")
      ins.append(f"mov x15, #{arg[0]}")
      ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl {shifts[arg[2]] if arg[2] is not None and arg[2] in shifts else '#0'}]")
    elif uop == UOps.COND_BRANCH:
      #TODO: this is a hack it shouldn't always be a cmp before a cond branch?
      if prev_uop == UOps.LOAD:
        ins.append(f"cmp {rtor[vin[0].nm]}, #0")
      ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}")
    elif uop == UOps.LABEL:
      ins.append(f"{arg[1:]}:")
    elif uop == UOps.ENDLOOP:
      # Increment the loop register and branch back while it is below the bound arg[0].
      mov_imm(arg[0], "x15")
      ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
      ins.append(f"cmp {rtor[vin[0].nm]}, x15")
      ins.append(f"b.lt loop_{arg[1]}")

    prev_uop = uop
    # store regs into memory if needed
    if out is not None and out.nm in mem_vars:
      ins.append(f"mov x15, {mem_vars[out.nm]}")
      ins.append(f"str {rtor[out.nm]}, [sp, x15]")

  # The stack is reserved and released in steps produced by compute_offsets; x19 keeps the entry sp.
  stack_steps = compute_offsets(var_size)
  header = [f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x19, sp"]
  prologue = [f"sub sp, sp, #{offset}" for offset in stack_steps]
  epilogue = [f"add sp, sp, #{offset}" for offset in stack_steps]
  return "\n".join(header + prologue + ins + epilogue + ["ret", "\n"])
|
||
def uops_to_arm64_asm(fn_nm:str, uops:List[UOp]) -> Tuple[str, List[int], List[int], bool]:
  """Translate uops into ARM64 assembly for the function fn_nm.

  Returns a 4-tuple: the assembly text, the global size reversed, the local
  size reversed, and the constant True.
  """
  arm_lang = ARM64Language()
  # uops_to_asmstyle fills arm_lang.ins and reports the launch dimensions.
  global_size, local_size = uops_to_asmstyle(arm_lang, fn_nm, uops)
  asm_text = specialize_to_arm64(fn_nm, arm_lang.ins)
  return asm_text, global_size[::-1], local_size[::-1], True
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters