Skip to content

Commit

Permalink
[ready] Replacing os with pathlib (tinygrad#1708)
Browse files (browse the repository at this point in the history)
* replace os.path with pathlib

* safe convert dirnames to pathlib

* replace all os.path.join

* fix cuda error

* change main chunk

* Reviewer fixes

* fix vgg

* Fixed everything

* Final fixes

* ensure consistency

* Change all parent.parent... to parents
  • Loading branch information
crnsh committed Aug 30, 2023
1 parent 355b02d commit a8aa13d
Show file tree
Hide file tree
Showing 30 changed files with 89 additions and 86 deletions.
2 changes: 1 addition & 1 deletion disassemblers/adreno/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def disasm(buf):
global fxn
if fxn is None:
shared = pathlib.Path(__file__).parent / "disasm.so"
if not os.path.isfile(shared):
if not shared.is_file():
os.system(f'cd {pathlib.Path(__file__).parent} && gcc -shared disasm-a3xx.c -o disasm.so')
fxn = ctypes.CDLL(shared.as_posix())['disasm']
#hexdump(buf)
Expand Down
8 changes: 5 additions & 3 deletions examples/compile_efficientnet.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
from pathlib import Path
from models.efficientnet import EfficientNet
from tinygrad.tensor import Tensor
from tinygrad.nn.state import safe_save
from extra.utils import fetch
from extra.export_model import export_model
from tinygrad.helpers import getenv
import ast, os
import ast

if __name__ == "__main__":
model = EfficientNet(0)
model.load_from_pretrained()
mode = "clang" if getenv("CLANG", "") != "" else "webgpu" if getenv("WEBGPU", "") != "" else ""
prg, inp_size, out_size, state = export_model(model, Tensor.randn(1,3,224,224), mode)
dirname = Path(__file__).parent
if getenv("CLANG", "") == "":
safe_save(state, os.path.join(os.path.dirname(__file__), "net.safetensors"))
safe_save(state, (dirname / "net.safetensors").as_posix())
ext = "js" if getenv("WEBGPU", "") != "" else "json"
with open(os.path.join(os.path.dirname(__file__), f"net.{ext}"), "w") as text_file:
with open(dirname / f"net.{ext}", "w") as text_file:
text_file.write(prg)
else:
cprog = [prg]
Expand Down
4 changes: 2 additions & 2 deletions examples/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def convert(name) -> Tensor:
def load(fn:str):
if fn.endswith('.index.json'):
with open(fn) as fp: weight_map = json.load(fp)['weight_map']
parts = {n: load(f'{os.path.dirname(fn)}/{os.path.basename(n)}') for n in set(weight_map.values())}
parts = {n: load(Path(fn).parent / Path(n).name) for n in set(weight_map.values())}
return {k: parts[n][k] for k, n in weight_map.items()}
elif fn.endswith('.safetensors'):
return safe_load(fn)
Expand Down Expand Up @@ -428,7 +428,7 @@ def greedy_until(self, prompt:str, until, max_length, temperature):


LLAMA_SUFFIX = {1: "", 2: "-2"}[args.gen]
MODEL_PATH = args.model or Path(__file__).parent.parent / f"weights/LLaMA{LLAMA_SUFFIX}/{args.size}"
MODEL_PATH = args.model or Path(__file__).parents[1] / f"weights/LLaMA{LLAMA_SUFFIX}/{args.size}"
TOKENIZER_PATH = (MODEL_PATH if MODEL_PATH.is_dir() else MODEL_PATH.parent) / "tokenizer.model"
print(f"using LLaMA{LLAMA_SUFFIX}-{args.size} model")
llama = LLaMa.build(MODEL_PATH, TOKENIZER_PATH, model_gen=args.gen, model_size=args.size, quantize=args.quantize)
Expand Down
2 changes: 1 addition & 1 deletion examples/mlperf/model_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def run(input_ids, input_mask, segment_ids):
from examples.mlperf.metrics import f1_score
from transformers import BertTokenizer

tokenizer = BertTokenizer(str(Path(__file__).parent.parent.parent / "weights/bert_vocab.txt"))
tokenizer = BertTokenizer(str(Path(__file__).parents[2] / "weights/bert_vocab.txt"))

c = 0
f1 = 0.0
Expand Down
14 changes: 7 additions & 7 deletions examples/so_vits_svc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# original implementation: https://github.com/svc-develop-team/so-vits-svc
from __future__ import annotations
import sys, os, logging, time, io, math, argparse, operator, numpy as np
import sys, logging, time, io, math, argparse, operator, numpy as np
from functools import partial, reduce
from pathlib import Path
from typing import Tuple, Optional, Type
Expand Down Expand Up @@ -468,14 +468,14 @@ def repeat_expand_2d_left(content, target_len): # content : [h, t]
return Tensor.stack(cols).transpose(0, 1)

def load_fairseq_cfg(checkpoint_path):
assert os.path.isfile(checkpoint_path)
assert Path(checkpoint_path).is_file()
state = torch_load(checkpoint_path)
cfg = state["cfg"] if ("cfg" in state and state["cfg"] is not None) else None
if cfg is None: raise RuntimeError(f"No cfg exist in state keys = {state.keys()}")
return HParams(**cfg)

def load_checkpoint_enc(checkpoint_path, model: ContentVec, optimizer=None, skip_list=[]):
assert os.path.isfile(checkpoint_path)
assert Path(checkpoint_path).is_file()
start_time = time.time()
checkpoint_dict = torch_load(checkpoint_path)
saved_state_dict = checkpoint_dict['model']
Expand Down Expand Up @@ -550,7 +550,7 @@ def get_encoder(ssl_dim) -> Type[SpeechEncoder]:
# DEMO USAGE (uses audio sample from LJ-Speech):
# python3 examples/so_vits_svc.py --model saul_goodman
#########################################################################################
SO_VITS_SVC_PATH = Path(__file__).parent.parent / "weights/So-VITS-SVC"
SO_VITS_SVC_PATH = Path(__file__).parents[1] / "weights/So-VITS-SVC"
VITS_MODELS = { # config_path, weights_path, config_url, weights_url
"saul_goodman" : (SO_VITS_SVC_PATH / "config_saul_gman.json", SO_VITS_SVC_PATH / "pretrained_saul_gman.pth", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/Saul_Goodman_80000/config.json", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/Saul_Goodman_80000/G_80000.pth"),
"drake" : (SO_VITS_SVC_PATH / "config_drake.json", SO_VITS_SVC_PATH / "pretrained_drake.pth", "https://huggingface.co/jaspa/so-vits-svc/resolve/main/aubrey/config_aubrey.json", "https://huggingface.co/jaspa/so-vits-svc/resolve/main/aubrey/pretrained_aubrey.pth"),
Expand All @@ -563,13 +563,13 @@ def get_encoder(ssl_dim) -> Type[SpeechEncoder]:
"contentvec": (SO_VITS_SVC_PATH / "contentvec_checkpoint.pt", "https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/hubert_base.pt")
}
ENCODER_MODEL = "contentvec"
DEMO_PATH, DEMO_URL = Path(__file__).parent.parent / "temp/LJ037-0171.wav", "https://keithito.com/LJ-Speech-Dataset/LJ037-0171.wav"
DEMO_PATH, DEMO_URL = Path(__file__).parents[1] / "temp/LJ037-0171.wav", "https://keithito.com/LJ-Speech-Dataset/LJ037-0171.wav"
if __name__=="__main__":
logging.basicConfig(stream=sys.stdout, level=(logging.INFO if DEBUG < 1 else logging.DEBUG))
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", default=None, help=f"Specify the model to use. All supported models: {VITS_MODELS.keys()}", required=True)
parser.add_argument("-f", "--file", default=DEMO_PATH, help=f"Specify the path of the input file")
parser.add_argument("--out_dir", default=str(Path(__file__).parent.parent / "temp"), help="Specify the output path.")
parser.add_argument("--out_dir", default=str(Path(__file__).parents[1] / "temp"), help="Specify the output path.")
parser.add_argument("--out_path", default=None, help="Specify the full output path. Overrides the --out_dir and --name parameter.")
parser.add_argument("--base_name", default="test", help="Specify the base of the output file name. Default is 'test'.")
parser.add_argument("--speaker", default=None, help="If not specified, the first available speaker is chosen. Usually there is only one speaker per model.")
Expand Down Expand Up @@ -600,7 +600,7 @@ def get_encoder(ssl_dim) -> Type[SpeechEncoder]:

### Loading audio and slicing ###
if audio_path == DEMO_PATH: download_if_not_present(DEMO_PATH, DEMO_URL)
assert os.path.isfile(audio_path) and Path(audio_path).suffix == ".wav"
assert Path(audio_path).is_file() and Path(audio_path).suffix == ".wav"
chunks = preprocess.cut(audio_path, db_thresh=slice_db)
audio_data, audio_sr = preprocess.chunks2audio(audio_path, chunks)

Expand Down
7 changes: 3 additions & 4 deletions examples/stable_diffusion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# https://arxiv.org/pdf/2112.10752.pdf
# https://github.com/ekagra-ranjan/huggingface-blog/blob/main/stable_diffusion.md
import os
import tempfile
from pathlib import Path
import gzip, argparse, math, re
Expand Down Expand Up @@ -424,7 +423,7 @@ def __call__(self, input_ids):
# Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
@lru_cache()
def default_bpe():
fn = Path(__file__).parent.parent / "weights/bpe_simple_vocab_16e6.txt.gz"
fn = Path(__file__).parents[1] / "weights/bpe_simple_vocab_16e6.txt.gz"
download_file("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", fn)
return fn

Expand Down Expand Up @@ -558,13 +557,13 @@ def __init__(self):
# cond_stage_model.transformer.text_model

# this is sd-v1-4.ckpt
FILENAME = Path(__file__).parent.parent / "weights/sd-v1-4.ckpt"
FILENAME = Path(__file__).parents[1] / "weights/sd-v1-4.ckpt"

if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--steps', type=int, default=5, help="Number of steps in diffusion")
parser.add_argument('--prompt', type=str, default="a horse sized cat eating a bagel", help="Phrase to render")
parser.add_argument('--out', type=str, default=os.path.join(tempfile.gettempdir(), "rendered.png"), help="Output filename")
parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
parser.add_argument('--noshow', action='store_true', help="Don't show the image")
parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
parser.add_argument('--timing', action='store_true', help="Print timing per step")
Expand Down
8 changes: 3 additions & 5 deletions examples/vgg7.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import sys
import os
import random
import json
import numpy
from pathlib import Path
from PIL import Image
from tinygrad.tensor import Tensor
from tinygrad.nn.optim import SGD
Expand Down Expand Up @@ -80,8 +80,7 @@ def load_and_save(path, save):

vgg7.load_waifu2x_json(json.load(open(src, "rb")))

if not os.path.isdir(model):
os.mkdir(model)
Path(model).mkdir(exist_ok=True)
load_and_save(model, True)
elif cmd == "execute":
model = sys.argv[2]
Expand All @@ -102,8 +101,7 @@ def load_and_save(path, save):
elif cmd == "new":
model = sys.argv[2]

if not os.path.isdir(model):
os.mkdir(model)
Path(model).mkdir(exist_ok=True)
load_and_save(model, True)
elif cmd == "train":
model = sys.argv[2]
Expand Down
6 changes: 3 additions & 3 deletions examples/vgg7_helpers/kinne.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from tinygrad.tensor import Tensor
import numpy
import os
from pathlib import Path

# Format Details:
# A KINNE parameter set is stored as a set of files named "snoop_bin_*.bin",
Expand Down Expand Up @@ -35,8 +35,8 @@ def __init__(self, base: str, save: bool):
It is important that if you wish to save in the current directory,
you use ".", not the empty string.
"""
if save and not os.path.isdir(base):
os.mkdir(base)
if save:
Path(base).mkdir(exist_ok=True)
self.base = base + "/snoop_bin_"
self.next_part_index = 0
self.save = save
Expand Down
12 changes: 6 additions & 6 deletions examples/vits.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import json, logging, math, os, re, sys, time, wave, argparse, numpy as np
import json, logging, math, re, sys, time, wave, argparse, numpy as np
from functools import reduce
from pathlib import Path
from typing import List
Expand Down Expand Up @@ -522,7 +522,7 @@ def load_model(symbols, hps, model) -> Synthesizer:
_ = load_checkpoint(weights_path, net_g, None)
return net_g
def load_checkpoint(checkpoint_path, model: Synthesizer, optimizer=None, skip_list=[]):
assert os.path.isfile(checkpoint_path)
assert Path(checkpoint_path).is_file()
start_time = time.time()
checkpoint_dict = torch_load(checkpoint_path)
iteration, learning_rate = checkpoint_dict['iteration'], checkpoint_dict['learning_rate']
Expand Down Expand Up @@ -556,8 +556,8 @@ def load_checkpoint(checkpoint_path, model: Synthesizer, optimizer=None, skip_li
return model, optimizer, learning_rate, iteration

def download_if_not_present(file_path: Path, url: str):
if not os.path.isfile(file_path):
logging.info(f"Did not find {file_path}, downloading...")
if not file_path.is_file():
logging.info(f"Did not find {file_path.as_posix()}, downloading...")
download_file(url, file_path)
return file_path

Expand Down Expand Up @@ -649,7 +649,7 @@ def _expand_number(self, _inflect, m):
# anime lady 1 | --model_to_use uma_trilingual --speaker_id 36
# anime lady 2 | --model_to_use uma_trilingual --speaker_id 121
#########################################################################################
VITS_PATH = Path(__file__).parent.parent / "weights/VITS/"
VITS_PATH = Path(__file__).parents[1] / "weights/VITS/"
MODELS = { # config_path, weights_path, config_url, weights_url
"ljs": (VITS_PATH / "config_ljs.json", VITS_PATH / "pretrained_ljs.pth", "https://raw.githubusercontent.com/jaywalnut310/vits/main/configs/ljs_base.json", "https://drive.google.com/uc?export=download&id=1q86w74Ygw2hNzYP9cWkeClGT5X25PvBT&confirm=t"),
"vctk": (VITS_PATH / "config_vctk.json", VITS_PATH / "pretrained_vctk.pth", "https://raw.githubusercontent.com/jaywalnut310/vits/main/configs/vctk_base.json", "https://drive.google.com/uc?export=download&id=11aHOlhnxzjpdWDpsz1vFDCzbeEfoIxru&confirm=t"),
Expand All @@ -665,7 +665,7 @@ def _expand_number(self, _inflect, m):
parser.add_argument("--model_to_use", default="vctk", help="Specify the model to use. Default is 'vctk'.")
parser.add_argument("--speaker_id", type=int, default=6, help="Specify the speaker ID. Default is 6.")
parser.add_argument("--out_path", default=None, help="Specify the full output path. Overrides the --out_dir and --name parameter.")
parser.add_argument("--out_dir", default=str(Path(__file__).parent.parent / "temp"), help="Specify the output path.")
parser.add_argument("--out_dir", default=str(Path(__file__).parents[1] / "temp"), help="Specify the output path.")
parser.add_argument("--base_name", default="test", help="Specify the base of the output file name. Default is 'test'.")
parser.add_argument("--text_to_synthesize", default="""Hello person. If the code you are contributing isn't some of the highest quality code you've written in your life, either put in the effort to make it great, or don't bother.""", help="Specify the text to synthesize. Default is a greeting message.")
parser.add_argument("--noise_scale", type=float, default=0.667, help="Specify the noise scale. Default is 0.667.")
Expand Down
2 changes: 1 addition & 1 deletion examples/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def prep_audio(waveform=None, sr=RATE) -> Tensor:
"as": "assamese", "tt": "tatar", "haw": "hawaiian", "ln": "lingala", "ha": "hausa", "ba": "bashkir", "jw": "javanese", "su": "sundanese",
}

BASE = pathlib.Path(__file__).parent.parent / "weights"
BASE = pathlib.Path(__file__).parents[1] / "weights"
def get_encoding(n_vocab_in):
download_file("https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/gpt2.tiktoken", BASE / "gpt2.tiktoken")
ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in open(BASE / "gpt2.tiktoken") if line)}
Expand Down
3 changes: 2 additions & 1 deletion examples/yolov8-onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import os
from ultralytics import YOLO
import onnx
from pathlib import Path
from extra.onnx import get_run_onnx
from tinygrad.tensor import Tensor

os.chdir("/tmp")
if not os.path.isfile("yolov8n-seg.onnx"):
if not Path("yolov8n-seg.onnx").is_file():
model = YOLO("yolov8n-seg.pt")
model.export(format="onnx", imgsz=[480,640])
onnx_model = onnx.load(open("yolov8n-seg.onnx", "rb"))
Expand Down
10 changes: 4 additions & 6 deletions examples/yolov8.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from pathlib import Path
import cv2
from collections import defaultdict
import os
import time, io, sys
from tinygrad.nn.state import safe_load, load_state_dict

Expand Down Expand Up @@ -398,13 +397,12 @@ def return_all_trainable_modules(self):
yolo_variant = sys.argv[2] if len(sys.argv) >= 3 else (print("No variant given, so choosing 'n' as the default. Yolov8 has different variants, you can choose from ['n', 's', 'm', 'l', 'x']") or 'n')
print(f'running inference for YOLO version {yolo_variant}')

output_folder_path = './outputs_yolov8'
if not os.path.exists(output_folder_path):
os.makedirs(output_folder_path)
output_folder_path = Path('./outputs_yolov8')
output_folder_path.mkdir(parents=True, exist_ok=True)
#absolute image path or URL
image_location = [np.frombuffer(io.BytesIO(fetch(img_path)).read(), np.uint8)]
image = [cv2.imdecode(image_location[0], 1)]
out_paths = [os.path.join(output_folder_path, img_path.split("/")[-1].split('.')[0] + "_output" + '.' + img_path.split("/")[-1].split('.')[1])]
out_paths = [(output_folder_path / f"{Path(img_path).stem}_output{Path(img_path).suffix}").as_posix()]
if not isinstance(image[0], np.ndarray):
print('Error in image loading. Check your image file.')
sys.exit(1)
Expand All @@ -414,7 +412,7 @@ def return_all_trainable_modules(self):
depth, width, ratio = get_variant_multiples(yolo_variant)
yolo_infer = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)

weights_location = Path(__file__).parent.parent / "weights" / f'yolov8{yolo_variant}.safetensors'
weights_location = Path(__file__).parents[1] / "weights" / f'yolov8{yolo_variant}.safetensors'
download_file(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{yolo_variant}.safetensors', weights_location)

state_dict = safe_load(weights_location)
Expand Down
8 changes: 4 additions & 4 deletions extra/accel/ane/lib/ane.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
import os
from pathlib import Path
from ctypes import *
import json
import collections
Expand All @@ -8,13 +8,13 @@
import struct
faulthandler.enable()

basedir = os.path.dirname(os.path.abspath(os.path.realpath(__file__)))
basedir = Path(__file__).resolve().parent

libane = None
aneregs = None
def init_libane():
global libane, aneregs
libane = cdll.LoadLibrary(os.path.join(basedir, "libane.dylib"))
libane = cdll.LoadLibrary((basedir / "libane.dylib").as_posix())

libane.ANE_Compile.argtypes = [c_char_p, c_int]
libane.ANE_Compile.restype = c_void_p
Expand All @@ -29,7 +29,7 @@ def init_libane():

#libane.ANE_RegDebug.restype = c_char_p

with open(os.path.join(basedir, "aneregs.json")) as f:
with open(basedir / "aneregs.json") as f:
aneregs = json.load(f)

ANE_Struct = [
Expand Down
7 changes: 4 additions & 3 deletions extra/augment.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import numpy as np
from PIL import Image
import os
from pathlib import Path
import sys
sys.path.append(os.getcwd())
sys.path.append(os.path.join(os.getcwd(), 'test'))
cwd = Path.cwd()
sys.path.append(cwd.as_posix())
sys.path.append((cwd / 'test').as_posix())
from extra.datasets import fetch_mnist
from tqdm import trange

Expand Down
14 changes: 8 additions & 6 deletions extra/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import os, random, gzip, tarfile, pickle
import random, gzip, tarfile, pickle
import numpy as np
from pathlib import Path
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes
from extra.utils import download_file

def fetch_mnist():
parse = lambda file: np.frombuffer(gzip.open(file).read(), dtype=np.uint8).copy()
X_train = parse(os.path.dirname(__file__)+"/mnist/train-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28*28)).astype(np.float32)
Y_train = parse(os.path.dirname(__file__)+"/mnist/train-labels-idx1-ubyte.gz")[8:]
X_test = parse(os.path.dirname(__file__)+"/mnist/t10k-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28*28)).astype(np.float32)
Y_test = parse(os.path.dirname(__file__)+"/mnist/t10k-labels-idx1-ubyte.gz")[8:]
dirname = Path(__file__).parent.resolve()
X_train = parse(dirname / "mnist/train-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28*28)).astype(np.float32)
Y_train = parse(dirname / "mnist/train-labels-idx1-ubyte.gz")[8:]
X_test = parse(dirname / "mnist/t10k-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28*28)).astype(np.float32)
Y_test = parse(dirname / "mnist/t10k-labels-idx1-ubyte.gz")[8:]
return X_train, Y_train, X_test, Y_test

cifar_mean = [0.4913997551666284, 0.48215855929893703, 0.4465309133731618]
Expand All @@ -31,7 +33,7 @@ def _load_disk_tensor(sz, bs, db_list, path, shuffle=False):
Y[idx:idx+bs].assign(y[order])
idx += bs
return X, Y
fn = os.path.dirname(__file__)+"/cifar-10-python.tar.gz"
fn = Path(__file__).parent.resolve() / "cifar-10-python.tar.gz"
download_file('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', fn)
tt = tarfile.open(fn, mode='r:gz')
db = [pickle.load(tt.extractfile(f'cifar-10-batches-py/data_batch_{i}'), encoding="bytes") for i in range(1,6)]
Expand Down
2 changes: 1 addition & 1 deletion extra/datasets/squad.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def iterate(tokenizer, start=0):
yield features, example

if __name__ == "__main__":
tokenizer = BertTokenizer(str(Path(__file__).parent.parent.parent / "weights" / "bert_vocab.txt"))
tokenizer = BertTokenizer(str(Path(__file__).parents[2] / "weights" / "bert_vocab.txt"))

X, Y = next(iterate(tokenizer))
print(" ".join(X[0]["tokens"]))
Expand Down
Loading

0 comments on commit a8aa13d

Please sign in to comment.