Torch/LLVM/arm F64 support (tinygrad#1551)
dc-dc-dc committed Aug 16, 2023
1 parent 913263c commit d17eccc
Showing 6 changed files with 18 additions and 7 deletions.
6 changes: 6 additions & 0 deletions test/test_dtype.py
@@ -91,6 +91,12 @@ def test_casts_from_half(self): _test_casts_from([1,2,3,4], source_dtype=dtypes.
def test_half_upcast_ops(self): _test_ops(a_dtype=dtypes.float16, b_dtype=dtypes.float32, target_dtype=dtypes.float32)
def test_upcast_to_half_ops(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.float16, target_dtype=dtypes.float16)

+@unittest.skipIf(Device.DEFAULT in ["WEBGPU", "METAL"], "float64 is not supported by some backends")
+class TestDoubleDtype(unittest.TestCase):
+def test_float64_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.double), np.double, [1,2,3,4])
+def test_casts_to_float64(self): _test_casts_to([1,2,3,4], source_dtypes=[dtypes.float32, dtypes.int32, dtypes.uint8], target_dtype=dtypes.float64)
+def test_upcast_to_float64_ops(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.float64, target_dtype=dtypes.float64)

@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does not support int8")
class TestInt8Dtype(unittest.TestCase):
def test_int8_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int8), np.int8, [1,2,3,4])
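For orientation, here is a minimal sketch (not part of the commit) of the round trip the new TestDoubleDtype cases exercise, assuming the default backend supports float64 (e.g. TORCH or LLVM, the backends this commit targets):

```python
# Hedged sketch, not from the diff: build a float64 tensor and read it back as numpy.
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

t = Tensor([1, 2, 3, 4], dtype=dtypes.double)   # the tests use dtypes.double and dtypes.float64 interchangeably
out = t.numpy()                                 # materializes on the device, then copies back to numpy
assert out.dtype == np.float64
np.testing.assert_allclose(out, [1, 2, 3, 4])
```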
6 changes: 3 additions & 3 deletions tinygrad/codegen/assembly_arm64.py
@@ -22,7 +22,7 @@ def specialize_to_arm64(fn_nm, asm):
ins = []
x_regs = ['x' + str(i) for i in reversed(range(29)) if i not in (10,11,12,13,14,15,16,17,18,19,20)]
s_regs = ['s' + str(i) for i in reversed(range(3,30))]
-type_to_reg = {dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
+type_to_reg = {dtypes.double: "d", dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
@@ -137,12 +137,12 @@ def allocate_regs(mvars):
mov_imm(arg[0], "x15")
ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] == dtypes.half else 'scvtf'} {rtor[out.nm]}, {reg_in}")
if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}")
elif uop == UOps.STORE:
shifts = {dtypes.int64: "#3", dtypes.half: "#1", dtypes.int8:"#2", dtypes.uint8: "#2", dtypes.bool: "#2"}
#NOTE: if need casting load var in s/h0 or x/w12 temp regs
reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] != dtypes.half else '' } {reg_out}, {rtor[vin[1].nm]}")
if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}")
ins.append(f"mov x15, #{arg[0]}")
ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl {shifts[arg[2]] if arg[2] is not None and arg[2] in shifts else '#0'}]")
elif uop == UOps.COND_BRANCH:
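On AArch64 the register prefix encodes the operand width (d = 64-bit float, s = 32-bit float, h = 16-bit float, w/x = 32/64-bit integer), and fcvt converts between floating-point precisions while scvtf converts a signed integer to floating point, which is why double joins half on the fcvt side of both branches above. A hedged, standalone restatement of that selection logic (convert_op is an illustrative helper, not a function in tinygrad):

```python
# Hedged sketch restating the selection logic from the diff.
from tinygrad.helpers import dtypes

type_to_reg = {dtypes.double: "d", dtypes.half: "h", dtypes.float32: "s", dtypes.bool: "w",
               dtypes.int8: "w", dtypes.int32: "w", dtypes.int64: "x",
               dtypes.uint8: "w", dtypes.uint32: "w", dtypes.uint64: "x"}

def convert_op(src_dtype):
  # fcvt: float-to-float precision change (half/single/double)
  # scvtf: signed integer register -> floating point
  return "fcvt" if src_dtype in (dtypes.half, dtypes.double) else "scvtf"

assert type_to_reg[dtypes.double] == "d" and convert_op(dtypes.double) == "fcvt"
assert convert_op(dtypes.int32) == "scvtf"
```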
6 changes: 5 additions & 1 deletion tinygrad/renderer/llvmir.py
@@ -32,7 +32,7 @@ def int_const(x): return ir.Constant(ir.IntType(64), x)
TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=('fast',)), y, z, flags=('fast',)),
}

-dtype_to_llvm_dtype = {dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}
+dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}

def cast(bb, val, input_type, output_type):
if input_type == output_type: return val
@@ -44,6 +44,8 @@ def cast(bb, val, input_type, output_type):
val = bb[-1].sext(val, ir.IntType(32))
val = bb[-1].shl(val, ir.Constant(ir.IntType(32), 16))
val = bb[-1].bitcast(val, ir.FloatType())
+elif input_type == dtypes.float64:
+val = bb[-1].fptrunc(val, ir.FloatType())
else:
val = bb[-1].fpext(val, ir.FloatType())
return val
@@ -55,6 +57,8 @@ def cast(bb, val, input_type, output_type):
val = bb[-1].bitcast(val, ir.IntType(32))
val = bb[-1].lshr(val, ir.Constant(ir.IntType(32), 16))
val = bb[-1].trunc(val, ir.IntType(16))
+elif output_type == dtypes.float64:
+val = bb[-1].fpext(val, ir.DoubleType())
else:
val = bb[-1].fptrunc(val, dtype_to_llvm_dtype[output_type])
return val
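The cast() changes lean on two LLVM instructions: fpext to widen toward double and fptrunc to narrow back to the float32 working type, with dtypes.float64 now mapping to ir.DoubleType(). A minimal standalone llvmlite sketch of those instructions (module and function names are made up for the example):

```python
# Hedged llvmlite sketch (not tinygrad code): emit the fpext/fptrunc pair the new
# float64 branches of cast() rely on.
from llvmlite import ir

module = ir.Module(name="f64_cast_demo")              # hypothetical module name
fnty = ir.FunctionType(ir.DoubleType(), [ir.FloatType()])
fn = ir.Function(module, fnty, name="widen_to_f64")   # hypothetical function name
bb = ir.IRBuilder(fn.append_basic_block("entry"))
as_f64 = bb.fpext(fn.args[0], ir.DoubleType())        # float32 -> float64 (output_type == dtypes.float64)
as_f32 = bb.fptrunc(as_f64, ir.FloatType())           # float64 -> float32 (input_type == dtypes.float64)
bb.ret(bb.fpext(as_f32, ir.DoubleType()))
print(module)                                         # dump the textual LLVM IR
```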
3 changes: 2 additions & 1 deletion tinygrad/runtime/ops_metal.py
@@ -4,7 +4,7 @@
from typing import List, Any
from tinygrad.codegen.linearizer import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
-from tinygrad.helpers import prod, getenv, DEBUG, DType
+from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes
from tinygrad.ops import Compiled
from tinygrad.runtime.lib import RawBufferMapped

@@ -23,6 +23,7 @@ def synchronize(self):

class RawMetalBuffer(RawBufferMapped):
def __init__(self, size:int, dtype:DType):
+assert dtype != dtypes.double, f"METAL does not support {dtype.name}"
super().__init__(size, dtype, METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared))
def __del__(self):
self._buf.release()
2 changes: 1 addition & 1 deletion tinygrad/runtime/ops_torch.py
@@ -6,7 +6,7 @@
from tinygrad.runtime.lib import RawBuffer

device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
-type_map = {torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8, torch.bool: dtypes.bool}
+type_map = {torch.float64: dtypes.float64, torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8, torch.bool: dtypes.bool}
inverse_type_map = {v:k for k,v in type_map.items()}

torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
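With torch.float64 added to type_map (and hence to inverse_type_map), float64 data can pass through the TORCH backend's dtype translation in both directions. A hedged sketch, assuming the TORCH backend is active (e.g. TORCH=1):

```python
# Hedged sketch: float64 arithmetic stays float64 on the torch backend.
# Assumes TORCH is the selected device.
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

a = Tensor([1, 2, 3, 4], dtype=dtypes.float64)
b = Tensor([0.5, 0.5, 0.5, 0.5], dtype=dtypes.float64)
out = (a + b).numpy()
assert out.dtype == np.float64
np.testing.assert_allclose(out, [1.5, 2.5, 3.5, 4.5])
```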
2 changes: 1 addition & 1 deletion tinygrad/runtime/ops_webgpu.py
@@ -31,7 +31,7 @@ def __call__(self, global_size, local_size, *bufs, wait=False):

class RawWebGPUBuffer(RawBufferCopyIn):
def __init__(self, size:int, dtype:DType):
-assert dtype not in [dtypes.int8,dtypes.uint8,dtypes.int64,dtypes.uint64], f"dtype {dtype} not supported on WEBGPU"
+assert dtype not in [dtypes.int8,dtypes.uint8,dtypes.int64,dtypes.uint64,dtypes.double], f"dtype {dtype} not supported on WEBGPU"
super().__init__(size, dtype, device.create_buffer(size=size*dtype.itemsize, usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.COPY_SRC))
def _copyin(self, x:np.ndarray): device.queue.write_buffer(self._buf, 0, np.ascontiguousarray(x))
def toCPU(self) -> np.ndarray: return np.frombuffer(device.queue.read_buffer(self._buf, 0), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore
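METAL and WEBGPU do not gain float64 support; instead the new asserts in RawMetalBuffer and RawWebGPUBuffer reject double at buffer-allocation time, which matches the skip on TestDoubleDtype above. A hedged sketch of that failure mode (assumes WEBGPU or METAL is the active backend; the exact point of failure and message text may differ):

```python
# Hedged sketch (assumes WEBGPU or METAL is the active device): realizing a
# float64 tensor should fail fast with an AssertionError instead of producing
# a kernel the backend cannot run.
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

try:
  Tensor([1, 2, 3, 4], dtype=dtypes.double).realize()
except AssertionError as err:
  print("float64 rejected up front:", err)
```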
