Skip to content

Commit

Permalink
add name support to fetch (tinygrad#2407)
Browse files Browse the repository at this point in the history
* add name support

* use fetch in gpt2

* remove requests from main lib, networkx also optional

* umm, keep that assert

* updates to fetch

* i love the walrus so much

* stop bundling mnist with tinygrad

* err, https

* download cache names

* add DOWNLOAD_CACHE_VERSION

* need env.

* ugh, wrong path

* replace get_child
  • Loading branch information
geohot committed Nov 23, 2023
1 parent 397c093 commit 095e2ce
Show file tree
Hide file tree
Showing 16 changed files with 73 additions and 79 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ jobs:
shell: bash
- name: Run LLaMA
run: |
JIT=0 python3 examples/llama.py --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
JIT=1 python3 examples/llama.py --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
shell: bash
- name: Run GPT2
run: |
Expand Down Expand Up @@ -121,8 +121,8 @@ jobs:
shell: bash
- name: Run LLaMA
run: |
JIT=0 python3 examples/llama.py --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
JIT=1 python3 examples/llama.py --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
shell: bash
- name: Run GPT2
run: |
Expand Down
18 changes: 13 additions & 5 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
name: Unit Tests
env:
# increment this when downloads substantially change to avoid the internet
DOWNLOAD_CACHE_VERSION: '1'

on:
push:
Expand Down Expand Up @@ -85,7 +88,7 @@ jobs:
uses: actions/cache@v3
with:
path: ~/.cache/tinygrad/downloads/
key: downloads-cache
key: downloads-cache-cpu-${{ env.DOWNLOAD_CACHE_VERSION }}
- name: Install Dependencies
run: pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
- name: Run Pytest
Expand Down Expand Up @@ -118,7 +121,7 @@ jobs:
uses: actions/cache@v3
with:
path: ~/.cache/tinygrad/downloads/
key: downloads-cache
key: downloads-cache-torch-${{ env.DOWNLOAD_CACHE_VERSION }}
- name: Install Dependencies
run: pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
- name: Run Pytest
Expand Down Expand Up @@ -155,6 +158,11 @@ jobs:
with:
path: ${{ env.Python3_ROOT_DIR }}/lib/python3.11/site-packages
key: testing-packages-${{ hashFiles('**/setup.py') }}
- name: Cache downloads
uses: actions/cache@v3
with:
path: ~/.cache/tinygrad/downloads/
key: downloads-cache-${{ matrix.task }}-${{ env.DOWNLOAD_CACHE_VERSION }}
- name: Install Dependencies
run: pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
- if: ${{ matrix.task == 'optimage' }}
Expand Down Expand Up @@ -229,7 +237,7 @@ jobs:
uses: actions/cache@v3
with:
path: ~/Library/Caches/tinygrad/downloads/
key: downloads-cache
key: downloads-cache-metal-${{ env.DOWNLOAD_CACHE_VERSION }}
- name: Test LLaMA compile speed
run: PYTHONPATH="." METAL=1 python test/external/external_test_speed_llama.py
#- name: Run dtype test
Expand Down Expand Up @@ -293,8 +301,8 @@ jobs:
- name: Cache downloads
uses: actions/cache@v3
with:
path: ~/Library/Caches/tinygrad/downloads/
key: downloads-cache
path: ~/.cache/tinygrad/downloads/
key: downloads-cache-${{ matrix.backend }}-${{ env.DOWNLOAD_CACHE_VERSION }}
- name: Set env
run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'clang' && 'CLANG=1' || matrix.backend == 'gpu' && 'GPU=1' || matrix.backend == 'cuda' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\n' || matrix.backend == 'PTX' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\nPTX=1' || matrix.backend == 'triton' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\nTRITON=1\nTRITON_PTXAS_PATH=/usr/bin/ptxas'}}" >> $GITHUB_ENV
- name: Install OpenCL
Expand Down
8 changes: 3 additions & 5 deletions examples/compile_efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
from extra.models.efficientnet import EfficientNet
from tinygrad.tensor import Tensor
from tinygrad.nn.state import safe_save
from extra.utils import fetch
from extra.export_model import export_model
from tinygrad.helpers import getenv
from tinygrad.helpers import getenv, fetch
import ast

if __name__ == "__main__":
Expand All @@ -21,11 +20,10 @@
else:
cprog = [prg]
# image library!
cprog += ["#define STB_IMAGE_IMPLEMENTATION", fetch("https://raw.githubusercontent.com/nothings/stb/master/stb_image.h").decode('utf-8').replace("half", "_half")]
cprog += ["#define STB_IMAGE_IMPLEMENTATION", fetch("https://raw.githubusercontent.com/nothings/stb/master/stb_image.h").read_text().replace("half", "_half")]

# imagenet labels, move to datasets?
lbls = fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt")
lbls = ast.literal_eval(lbls.decode('utf-8'))
lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text())
lbls = ['"'+lbls[i]+'"' for i in range(1000)]
inputs = "\n".join([f"float {inp}[{inp_size}];" for inp,inp_size in inp_sizes.items()])
outputs = "\n".join([f"float {out}[{out_size}];" for out,out_size in out_sizes.items()])
Expand Down
5 changes: 2 additions & 3 deletions examples/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from tinygrad.jit import TinyJit
import tiktoken
from tinygrad.nn.state import torch_load, load_state_dict
from extra.utils import fetch_as_file
from tinygrad.helpers import GlobalCounters, Timing, DEBUG, getenv
from tinygrad.helpers import GlobalCounters, Timing, DEBUG, getenv, fetch

MAX_CONTEXT = 128

Expand Down Expand Up @@ -106,7 +105,7 @@ def build(model_size="gpt2"):
tokenizer = tiktoken.get_encoding("gpt2")

model = Transformer(**MODEL_PARAMS[model_size])
weights = torch_load(fetch_as_file(f'https://huggingface.co/{model_size}/resolve/main/pytorch_model.bin'))
weights = torch_load(fetch(f'https://huggingface.co/{model_size}/resolve/main/pytorch_model.bin'))
# special treatment for the Conv1D weights we need to transpose
transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
for k in weights.keys():
Expand Down
14 changes: 3 additions & 11 deletions examples/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
from tqdm import tqdm
from tinygrad.tensor import Tensor
from tinygrad.ops import Device
from tinygrad.helpers import dtypes, GlobalCounters, Timing, Context, getenv
from tinygrad.helpers import dtypes, GlobalCounters, Timing, Context, getenv, fetch
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
from extra.utils import download_file
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
from tinygrad.jit import TinyJit

Expand Down Expand Up @@ -405,10 +404,7 @@ def __call__(self, input_ids):

# Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
@lru_cache()
def default_bpe():
fn = Path(__file__).parents[1] / "weights/bpe_simple_vocab_16e6.txt.gz"
download_file("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", fn)
return fn
def default_bpe(): return fetch("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", "bpe_simple_vocab_16e6.txt.gz")

def get_pairs(word):
"""Return set of symbol pairs in a word.
Expand Down Expand Up @@ -576,9 +572,6 @@ def __call__(self, unconditional_context, context, latent, timestep, alphas, alp
# ** ldm.modules.encoders.modules.FrozenCLIPEmbedder
# cond_stage_model.transformer.text_model

# this is sd-v1-4.ckpt
FILENAME = Path(__file__).parents[1] / "weights/sd-v1-4.ckpt"

if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--steps', type=int, default=5, help="Number of steps in diffusion")
Expand All @@ -595,8 +588,7 @@ def __call__(self, unconditional_context, context, latent, timestep, alphas, alp
model = StableDiffusion()

# load in weights
download_file('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', FILENAME)
load_state_dict(model, torch_load(FILENAME)['state_dict'], strict=False)
load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)

if args.fp16:
for l in get_state_dict(model).values():
Expand Down
25 changes: 11 additions & 14 deletions examples/webgpu/stable_diffusion/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from tinygrad.nn.state import get_state_dict, safe_save, safe_load_metadata, torch_load, load_state_dict
from tinygrad.tensor import Tensor
from tinygrad.ops import Device
from extra.utils import download_file
from tinygrad.helpers import fetch
from typing import NamedTuple, Any, List
from pathlib import Path
import argparse
Expand All @@ -28,8 +28,6 @@ def convert_f32_to_f16(input_file, output_file):
front_float16_values.tofile(f)
rest_float32_values.tofile(f)

FILENAME = Path(__file__).parent.parent.parent.parent / "weights/sd-v1-4.ckpt"

def split_safetensor(fn):
_, json_len, metadata = safe_load_metadata(fn)
text_model_offset = 3772703308
Expand All @@ -40,7 +38,7 @@ def split_safetensor(fn):
if (metadata[k]["data_offsets"][0] < text_model_offset):
metadata[k]["data_offsets"][0] = int(metadata[k]["data_offsets"][0]/2)
metadata[k]["data_offsets"][1] = int(metadata[k]["data_offsets"][1]/2)

last_offset = 0
part_end_offsets = []

Expand All @@ -51,7 +49,7 @@ def split_safetensor(fn):
break

part_offset = offset - last_offset

if (part_offset >= chunk_size):
part_end_offsets.append(8+json_len+offset)
last_offset = offset
Expand All @@ -60,15 +58,15 @@ def split_safetensor(fn):
net_bytes = bytes(open(fn, 'rb').read())
part_end_offsets.append(text_model_start+8+json_len)
cur_pos = 0

for i, end_pos in enumerate(part_end_offsets):
with open(f'./net_part{i}.safetensors', "wb+") as f:
f.write(net_bytes[cur_pos:end_pos])
cur_pos = end_pos

with open(f'./net_textmodel.safetensors', "wb+") as f:
f.write(net_bytes[text_model_start+8+json_len:])

return part_end_offsets

if __name__ == "__main__":
Expand All @@ -81,20 +79,19 @@ def split_safetensor(fn):
model = StableDiffusion()

# load in weights
download_file('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', FILENAME)
load_state_dict(model, torch_load(FILENAME)['state_dict'], strict=False)
load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)

class Step(NamedTuple):
name: str = ""
input: List[Tensor] = []
forward: Any = None

sub_steps = [
Step(name = "textModel", input = [Tensor.randn(1, 77)], forward = model.cond_stage_model.transformer.text_model),
Step(name = "textModel", input = [Tensor.randn(1, 77)], forward = model.cond_stage_model.transformer.text_model),
Step(name = "diffusor", input = [Tensor.randn(1, 77, 768), Tensor.randn(1, 77, 768), Tensor.randn(1,4,64,64), Tensor.rand(1), Tensor.randn(1), Tensor.randn(1), Tensor.randn(1)], forward = model),
Step(name = "decoder", input = [Tensor.randn(1,4,64,64)], forward = model.decode)
]

prg = ""

def compile_step(model, step: Step):
Expand All @@ -109,15 +106,15 @@ def compile_step(model, step: Step):
gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:input{i}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,(_,value) in enumerate(special_names.items()) if "output" not in value])
input_writer = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'data{i});' + f"\n gpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);" for i,(_,value) in enumerate(special_names.items()) if value != "output0"])
return f"""\n var {step.name} = function() {{
{kernel_code}
return {{
"setup": async (device, safetensor) => {{
const metadata = getTensorMetadata(safetensor[0]);
{bufs}
{gpu_write_bufs}
const gpuReadBuffer = device.createBuffer({{ size: output0.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});
Expand All @@ -140,7 +137,7 @@ def compile_step(model, step: Step):
gpuReadBuffer.unmap();
return resultBuffer;
}}
}}
}}
}}
}}
"""
Expand Down
11 changes: 5 additions & 6 deletions extra/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import os, gzip, tarfile, pickle
import numpy as np
from pathlib import Path
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, fetch

def fetch_mnist(tensors=False):
parse = lambda file: np.frombuffer(gzip.open(file).read(), dtype=np.uint8).copy()
dirname = Path(__file__).parent.resolve()
X_train = parse(dirname / "mnist/train-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28*28)).astype(np.float32)
Y_train = parse(dirname / "mnist/train-labels-idx1-ubyte.gz")[8:]
X_test = parse(dirname / "mnist/t10k-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28*28)).astype(np.float32)
Y_test = parse(dirname / "mnist/t10k-labels-idx1-ubyte.gz")[8:]
BASE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/" # http:https://yann.lecun.com/exdb/mnist/ lacks https
X_train = parse(fetch(f"{BASE_URL}train-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28*28)).astype(np.float32)
Y_train = parse(fetch(f"{BASE_URL}train-labels-idx1-ubyte.gz"))[8:]
X_test = parse(fetch(f"{BASE_URL}t10k-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28*28)).astype(np.float32)
Y_test = parse(fetch(f"{BASE_URL}t10k-labels-idx1-ubyte.gz"))[8:]
if tensors: return Tensor(X_train).reshape(-1, 1, 28, 28), Tensor(Y_train), Tensor(X_test).reshape(-1, 1, 28, 28), Tensor(Y_test)
else: return X_train, Y_train, X_test, Y_test

Expand Down
Binary file removed extra/datasets/mnist/t10k-images-idx3-ubyte.gz
Binary file not shown.
Binary file removed extra/datasets/mnist/t10k-labels-idx1-ubyte.gz
Binary file not shown.
Binary file removed extra/datasets/mnist/train-images-idx3-ubyte.gz
Binary file not shown.
Binary file removed extra/datasets/mnist/train-labels-idx1-ubyte.gz
Binary file not shown.
7 changes: 3 additions & 4 deletions extra/models/efficientnet.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import math
from tinygrad.tensor import Tensor
from tinygrad.nn import BatchNorm2d
from extra.utils import get_child
from tinygrad.helpers import get_child, fetch
from tinygrad.nn.state import torch_load

class MBConvBlock:
def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se, track_running_stats=True):
Expand Down Expand Up @@ -142,9 +143,7 @@ def load_from_pretrained(self):
7: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth"
}

from extra.utils import fetch_as_file
from tinygrad.nn.state import torch_load
b0 = torch_load(fetch_as_file(model_urls[self.number]))
b0 = torch_load(fetch(model_urls[self.number]))
for k,v in b0.items():
if k.endswith("num_batches_tracked"): continue
for cat in ['_conv_head', '_conv_stem', '_depthwise_conv', '_expand_conv', '_fc', '_project_conv', '_se_reduce', '_se_expand']:
Expand Down
13 changes: 2 additions & 11 deletions extra/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections import defaultdict
from typing import Union

from tinygrad.helpers import prod, getenv, DEBUG, dtypes
from tinygrad.helpers import prod, getenv, DEBUG, dtypes, get_child
from tinygrad.helpers import GlobalCounters
from tinygrad.tensor import Tensor
from tinygrad.lazy import LazyBuffer
Expand Down Expand Up @@ -47,13 +47,4 @@ def download_file(url, fp, skip_if_exists=True):
f.close()
Path(f.name).rename(fp)

def get_child(parent, key):
obj = parent
for k in key.split('.'):
if k.isnumeric():
obj = obj[int(k)]
elif isinstance(obj, dict):
obj = obj[k]
else:
obj = getattr(obj, k)
return obj

4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License"
],
install_requires=["numpy", "requests", "tqdm", "networkx", "pyopencl",
install_requires=["numpy", "tqdm", "pyopencl",
"pyobjc-framework-Metal; platform_system=='Darwin'",
"pyobjc-framework-Cocoa; platform_system=='Darwin'",
"pyobjc-framework-libdispatch; platform_system=='Darwin'"],
Expand Down Expand Up @@ -55,6 +55,8 @@
"sentencepiece",
"tiktoken",
"librosa",
"requests",
"networkx",
]
},
include_package_data=True)
6 changes: 3 additions & 3 deletions test/unit/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,13 @@ def test_round_up(self):

class TestFetch(unittest.TestCase):
def test_fetch_bad_http(self):
self.assertRaises(AssertionError, fetch, 'http:https://www.google.com/404')
self.assertRaises(Exception, fetch, 'http:https://www.google.com/404')

def test_fetch_small(self):
assert(len(fetch('https://google.com').read_bytes())>0)
assert(len(fetch('https://google.com', allow_caching=False).read_bytes())>0)

def test_fetch_img(self):
img = fetch("https://media.istockphoto.com/photos/hen-picture-id831791190")
img = fetch("https://media.istockphoto.com/photos/hen-picture-id831791190", allow_caching=False)
with Image.open(img) as pimg:
assert pimg.size == (705, 1024)

Expand Down
Loading

0 comments on commit 095e2ce

Please sign in to comment.