torch and numpy don't share ops anymore (tinygrad#2412)
* torch and numpy don't share ops anymore

* that should be filtered out elsewhere

* still const

* graph + enet example cleanup

* hmm, we do still need it because of symbolic
geohot committed Nov 24, 2023
1 parent 193be14 commit 8f89e21
Showing 7 changed files with 30 additions and 39 deletions.
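The heart of the change is in the interpreted backends: ops_cpu.py no longer exports a shared base_fxn_for_op table and ops_torch.py no longer merges it in, so each backend now spells out its complete op table with its own library's functions. A minimal sketch of the before/after pattern, using hypothetical string keys rather than tinygrad's real BinaryOps/ReduceOps enums:

```python
import operator
import numpy as np
import torch

# Before: both backends merged a shared table of generic callables.
base_fxn_for_op = {"ADD": operator.add, "MUL": operator.mul}
numpy_fxn_for_op_old = {**base_fxn_for_op, **{"EXP2": np.exp2}}
torch_fxn_for_op_old = {**base_fxn_for_op, **{"EXP2": torch.exp2}}

# After: each backend carries its own complete table, no shared base.
numpy_fxn_for_op_new = {"ADD": np.add, "MUL": np.multiply, "EXP2": np.exp2}
torch_fxn_for_op_new = {"ADD": torch.add, "MUL": torch.mul, "EXP2": torch.exp2}
```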
18 changes: 7 additions & 11 deletions examples/efficientnet.py
@@ -3,15 +3,12 @@
# a rough copy of
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
import sys
import io
import ast
import time
import cv2
import numpy as np
from PIL import Image
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv
from extra.utils import fetch
from tinygrad.helpers import getenv, fetch, Timing
from tinygrad.jit import TinyJit
from extra.models.efficientnet import EfficientNet
np.set_printoptions(suppress=True)
@@ -61,12 +58,12 @@ def infer(model, img):
model.load_from_pretrained()

# category labels
lbls = fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt")
lbls = ast.literal_eval(lbls.decode('utf-8'))
lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text())

# load image and preprocess
url = sys.argv[1] if len(sys.argv) >= 2 else "https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/showcase/stable_diffusion_by_tinygrad.jpg"
if url == 'webcam':
import cv2
cap = cv2.VideoCapture(0)
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
while 1:
@@ -85,8 +82,7 @@ def infer(model, img):
cap.release()
cv2.destroyAllWindows()
else:
img = Image.open(io.BytesIO(fetch(url)))
st = time.time()
out, _ = infer(model, img)
print(np.argmax(out), np.max(out), lbls[np.argmax(out)])
print(f"did inference in {(time.time()-st):2f}")
img = Image.open(fetch(url))
with Timing("did inference in "):
out, _ = infer(model, img)
print(np.argmax(out), np.max(out), lbls[np.argmax(out)])
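For context, a small sketch of the two helpers the example now leans on, assuming tinygrad.helpers.fetch downloads a URL to a local file and returns its path, and Timing is a context manager that prints the elapsed time prefixed with the given string (which is why io, time, and extra.utils.fetch drop out of the imports):

```python
from tinygrad.helpers import fetch, Timing

with Timing("fetched labels in "):
  lbls_path = fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt")
print(lbls_path.read_text()[:60])  # a local path comes back, so no io.BytesIO wrapper is needed
```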
3 changes: 3 additions & 0 deletions test/test_ops.py
@@ -138,6 +138,9 @@ def test_arange_simple(self):
def test_arange_big(self):
helper_test_op([], lambda: torch.arange(256), lambda: Tensor.arange(256), forward_only=True)

def test_sum_fake(self):
helper_test_op([(256, 1)], lambda x: x.sum(axis=1))

def test_sum_collapse(self):
helper_test_op([], lambda: torch.ones(256,256).sum(axis=1), lambda: Tensor.ones(256,256).sum(axis=1), forward_only=True)

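Outside the helper_test_op harness, the new test_sum_fake case boils down to checking that reducing over an axis of length 1 leaves the values untouched and only drops the axis; a plain numpy restatement:

```python
import numpy as np

x = np.random.rand(256, 1).astype(np.float32)
out = x.sum(axis=1)  # a "fake" reduce: the summed axis has length 1
assert out.shape == (256,)
assert np.allclose(out, x[:, 0])
```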
12 changes: 4 additions & 8 deletions tinygrad/graph.py
@@ -1,8 +1,4 @@
import os, atexit, functools
try:
import networkx as nx
except ImportError:
nx = None # graph won't work
from collections import defaultdict
from typing import Dict, List
from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, OpType, LazyOp
@@ -13,7 +9,6 @@

# **** debugging and graphing ****

G = nx.DiGraph() if nx is not None else None
cnts: Dict[OpType, int] = defaultdict(int)
if DEBUG >= 2:
def print_globalcounters():
@@ -22,6 +17,8 @@ def print_globalcounters():
f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms")
atexit.register(print_globalcounters)
if GRAPH:
import networkx as nx
G = nx.DiGraph()
def save_graph_exit():
for k,v in cnts.items(): print(k, v)
print("saving", G)
@@ -61,16 +58,15 @@ def add_st_node(nmx, nmo, label, st:ShapeTracker):
logops = open(getenv("LOGOPS", ""),"a") if getenv("LOGOPS", "") else None
def log_schedule_item(si: ScheduleItem):
if logops and si.ast.op not in LoadOps: logops.write(str(si.ast)+"\n")
show_graph = bool(GRAPH)
if not DEBUG and not show_graph: return
if not DEBUG and not GRAPH: return
if si.ast.op == LoadOps.CONTIGUOUS: setattr(si.out, 'node_id', nm(si.inputs[0].base))
if si.ast.op in {LoadOps.CONST, LoadOps.CONTIGUOUS}: return

op: List[Op] = [x.op for x in si.ast.get_lazyops()]
oporder = [LoadOps, TernaryOps, ReduceOps, BinaryOps, UnaryOps, MovementOps, BufferOps]
optype = type(sorted(op, key=lambda x: oporder.index(type(x)))[0])
cnts[optype] += 1
if show_graph:
if GRAPH:
assert si.out.base == si.out, "all outputs based"
top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", TernaryOps: "#c0c0c0", BufferOps: '#FF8080'}

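The graph.py edit drops the module-level try/except probe for networkx (which left nx = None when the package was missing) and instead imports it only inside the GRAPH branch, where the DiGraph is actually created. A condensed sketch of the new shape of that code, with the flag read from the environment as a stand-in for tinygrad's GRAPH helper:

```python
import os

GRAPH = int(os.getenv("GRAPH", "0"))  # stand-in for tinygrad's GRAPH flag

if GRAPH:
  import networkx as nx  # networkx is only a dependency when graphing is requested
  G = nx.DiGraph()
```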
2 changes: 1 addition & 1 deletion tinygrad/nn/state.py
@@ -86,7 +86,7 @@ def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad=None
if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
if DEBUG >= 2: print(f"WARNING: this torch load is slow. CPU to permute {intermediate_shape} with {permute_indexes}")
if DEBUG >= 3: print(f"WARNING: this torch load is slow. CPU to permute {intermediate_shape} with {permute_indexes}")
# TODO: find a nice way to support all shapetracker on disktensors
ret = ret.cpu().reshape(intermediate_shape).permute(permute_indexes)

3 changes: 1 addition & 2 deletions tinygrad/ops.py
@@ -213,8 +213,7 @@ def _interpret_ast(ast:LazyOp) -> str:
tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" if ast.op == BufferOps.CONST else f"{gstr(fxn_for_op[ast.op], ast.op)}(inputs[{ast.arg.idx-1}])"
for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(fxn_for_op[mop], mop)}({tmp}, {gstr(arg)})"
else:
inp = [_interpret_ast(src) for src in ast.src]
tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({', '.join(inp + ([gstr(ast.arg)] if ast.arg else []))})"
tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({', '.join([_interpret_ast(src) for src in ast.src] + ([gstr(ast.arg)] if ast.arg else []))})"

ret = f"a{len(lines)}"
lines.append(f" {ret} = {tmp}")
15 changes: 5 additions & 10 deletions tinygrad/runtime/ops_cpu.py
@@ -1,5 +1,4 @@
import numpy as np
import operator
from typing import Callable, Dict, Tuple, Optional
from tinygrad.helpers import dtypes, DType
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, ReduceOps, TernaryOps, Op, Interpreted
@@ -15,12 +14,6 @@ def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple
assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
return tuple(i for i,(a,b) in enumerate(zip(old_shape, new_shape)) if a != b)

base_fxn_for_op: Dict[Op, Callable] = {
BufferOps.MEM: lambda x: x._buf, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv,
ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), dtype=x.dtype, keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
}

# TODO: this should be global infrastructure
def output_type(x, y): return x.dtype if dtypes.from_np(x.dtype).priority > dtypes.from_np(y.dtype).priority else y.dtype
def match_types(x, y):
@@ -37,17 +30,19 @@ def mulacc(a, b, new_shape):
return expand(ret.reshape([(1 if i not in a_axes and i not in b_axes else s) for i,s in enumerate(new_shape)]), new_shape)
return mulacc

numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np), BufferOps.FROM_UNDERLYING: RawNumpyBuffer.fromCPU,
numpy_fxn_for_op: Dict[Op, Callable] = {
BufferOps.MEM: lambda x: x._buf, BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np), BufferOps.FROM_UNDERLYING: RawNumpyBuffer.fromCPU,
UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin,
UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False), UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x),
BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(output_type(x,y)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
BinaryOps.SUB: lambda x, y: np.subtract(*match_types(x, y)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)),
BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)).astype(output_type(x, y), copy=False), UnaryOps.SQRT: np.sqrt,
ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), dtype=x.dtype, keepdims=True) if x.shape != new_shape else x,
ReduceOps.MAX: lambda x, new_shape: x.max(shape_to_axis(x.shape, new_shape), keepdims=True) if x.shape != new_shape else x,
MovementOps.AS_STRIDED: lambda x, arg: np.ndarray(arg[0], buffer=np.require(x, requirements='C'), dtype=x.dtype, offset=arg[2]*x.dtype.itemsize, strides=tuple(y*x.dtype.itemsize for y in arg[1])),
MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, *match_types(a.copy(), b.copy()), optimize=True), lambda x: x.strides, np.broadcast_to),
TernaryOps.WHERE: np.where,
}}
}

CPUBuffer = Interpreted(RawNumpyBuffer, numpy_fxn_for_op)
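Since ops_torch.py now imports shape_to_axis instead of the removed base table, here is a short worked example of what that helper computes (restating the definition kept above in ops_cpu.py, not new behavior):

```python
from typing import Tuple

def shape_to_axis(old_shape: Tuple[int, ...], new_shape: Tuple[int, ...]) -> Tuple[int, ...]:
  assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
  return tuple(i for i, (a, b) in enumerate(zip(old_shape, new_shape)) if a != b)

# Reducing (256, 256) down to (256, 1) means summing over axis 1, which is
# exactly what the ReduceOps.SUM entry passes to the backend's keepdims sum.
assert shape_to_axis((256, 256), (256, 1)) == (1,)
assert shape_to_axis((4, 3, 2), (1, 3, 1)) == (0, 2)
```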
16 changes: 9 additions & 7 deletions tinygrad/runtime/ops_torch.py
@@ -1,9 +1,9 @@
import torch
import numpy as np
from typing import Dict, Callable, Optional
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, TernaryOps, Op, Interpreted
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, TernaryOps, ReduceOps, Op, Interpreted
from tinygrad.helpers import getenv, dtypes, prod, DType
from tinygrad.runtime.ops_cpu import base_fxn_for_op, einsum_mulacc
from tinygrad.runtime.ops_cpu import einsum_mulacc, shape_to_axis
from tinygrad.runtime.lib import RawBuffer

device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
@@ -30,22 +30,24 @@ def as_strided(x, arg):
arg[2] + sum((s-1)*a if a < 0 else 0 for (s,a) in zip(arg[0], arg[1]))).flip([i for i,a in enumerate(arg[1]) if a < 0])
return torch.as_strided(x.contiguous(), arg[0], arg[1], arg[2])

torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
# TODO: torch.tensor should work here
torch_fxn_for_op: Dict[Op, Callable] = {
# TODO: torch.tensor should work here. it doesn't due to "overflow" in uint8
#BufferOps.CONST: lambda val, dtype: torch.tensor(val, device=device, dtype=inverse_type_map[dtype]),
BufferOps.CONST: lambda val, dtype: torch.from_numpy(np.array(val, dtype=dtype.np)).requires_grad_(False).to(device),
BufferOps.FROM_UNDERLYING: lambda x: RawTorchBuffer(prod(x.shape), type_map[x.dtype], x),
BufferOps.CONST: lambda val, dtype: torch.from_numpy(np.array(val, dtype=dtype.np)).to(device),
BufferOps.MEM: lambda x: x._buf, BufferOps.FROM_UNDERLYING: lambda x: RawTorchBuffer(prod(x.shape), type_map[x.dtype], x),
UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.SQRT: lambda x: x.sqrt(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.SIN: torch.sin,
UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(next(k for k,v in type_map.items() if v==y[0])), UnaryOps.NEG: lambda x: torch.logical_not(x) if x.dtype is torch.bool else torch.neg(x),
BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).type(torch.promote_types(x.dtype, y.dtype)),
BinaryOps.ADD: lambda x,y: torch.add(*match_types(x, y)).type(output_type(x,y)),
BinaryOps.SUB: lambda x,y: torch.sub(*match_types(x, y, disallow_bool=True)).type(output_type(x,y)),
BinaryOps.MUL: lambda x,y: torch.mul(*match_types(x, y)).type(output_type(x,y)),
BinaryOps.DIV: lambda x,y: torch.div(*match_types(x, y)).type(torch.promote_types(x.dtype, y.dtype)),
ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), dtype=x.dtype, keepdims=True) if x.shape != new_shape else x,
ReduceOps.MAX: lambda x, new_shape: x.amax(shape_to_axis(x.shape, new_shape), keepdims=True) if x.shape != new_shape else x,
MovementOps.AS_STRIDED: as_strided, MovementOps.EXPAND: lambda x, arg: x.expand(arg),
MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]), # pylint: disable=E1102
TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(output_type(a,b)), lambda x: x.stride(), lambda x,s: x.expand(s)),
TernaryOps.WHERE: lambda x, y, z: torch.where(x != 0, y, z),
}}
}

TorchBuffer = Interpreted(RawTorchBuffer, torch_fxn_for_op)
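A minimal sketch of the CONST path the torch table keeps (see the TODO above): the scalar is built through numpy and then moved to the torch device, rather than via torch.tensor, which the comment notes misbehaves with uint8 overflow. make_const is only an illustrative helper name, and the device line is a local stand-in for the module-level device selected earlier in the file:

```python
import numpy as np
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def make_const(val, np_dtype):
  # numpy builds the scalar, torch.from_numpy wraps it, .to() moves it to the device
  return torch.from_numpy(np.array(val, dtype=np_dtype)).to(device)

print(make_const(2.0, np.float32))  # tensor(2.)
print(make_const(255, np.uint8))    # tensor(255, dtype=torch.uint8)
```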
