-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
make generation script into a notebook, makes much more sense that way i think, and much easier to use
- Loading branch information
Showing
3 changed files
with
163 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Shows how one can generate text given a prompt and some hyperparameters, using either minGPT or huggingface/transformers" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from transformers import GPT2Tokenizer, GPT2LMHeadModel\n", | ||
"from mingpt.model import GPT\n", | ||
"from mingpt.utils import sample\n", | ||
"from mingpt.utils import set_seed\n", | ||
"set_seed(3407)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"use_mingpt = True # use minGPT or huggingface/transformers model?\n", | ||
"model_type = 'gpt2-xl' # pretrained GPT-2 checkpoint to load (reported below as ~1557M params)\n", | ||
"device = 'cuda' # NOTE(review): assumes a CUDA GPU is available; set to 'cpu' otherwise" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"number of parameters: 1557.61M\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# load pretrained GPT-2 weights with the chosen implementation (minGPT or HF)\n", | ||
"if use_mingpt:\n", | ||
"    model = GPT.from_pretrained(model_type)\n", | ||
"else:\n", | ||
"    model = GPT2LMHeadModel.from_pretrained(model_type)\n", | ||
"    model.config.pad_token_id = model.config.eos_token_id # suppress a warning\n", | ||
"\n", | ||
"# ship model to device and set to eval mode\n", | ||
"model.to(device)\n", | ||
"model.eval(); # ';' suppresses the (very long) module repr in the cell output" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"\n", | ||
"def generate(prompt='', num_samples=10, steps=20, do_sample=True):\n", | ||
"    \"\"\"Sample text continuations of `prompt` and print them.\n", | ||
"\n", | ||
"    Relies on the notebook globals `model`, `model_type`, `use_mingpt`\n", | ||
"    and `device` defined in the cells above.\n", | ||
"\n", | ||
"    prompt: seed text; '' means unconditional samples seeded with the\n", | ||
"        special '<|endoftext|>' start token\n", | ||
"    num_samples: number of continuations generated in one batch\n", | ||
"    steps: number of new tokens sampled after the prompt\n", | ||
"    do_sample: forwarded to the sampler (both branches use top_k=40)\n", | ||
"    \"\"\"\n", | ||
"    # tokenize the input prompt into integer input sequence\n", | ||
"    tokenizer = GPT2Tokenizer.from_pretrained(model_type)\n", | ||
"    if prompt == '': # to create unconditional samples we feed in the special start token\n", | ||
"        prompt = '<|endoftext|>'\n", | ||
"    encoded_input = tokenizer(prompt, return_tensors='pt').to(device)\n", | ||
"    x = encoded_input['input_ids']\n", | ||
"    x = x.expand(num_samples, -1) # replicate the prompt row num_samples times (batched generation)\n", | ||
"\n", | ||
"    # forward the model `steps` times to get samples, in a batch\n", | ||
"    if use_mingpt:\n", | ||
"        y = sample(model=model, x=x, steps=steps, sample=do_sample, top_k=40)\n", | ||
"    else:\n", | ||
"        y = model.generate(x, max_new_tokens=steps, do_sample=do_sample, top_k=40)\n", | ||
"    \n", | ||
"    # decode and print each sampled continuation, separated by a rule\n", | ||
"    for i in range(num_samples):\n", | ||
"        out = tokenizer.decode(y[i].cpu().squeeze())\n", | ||
"        print('-'*80)\n", | ||
"        print(out)\n", | ||
"    " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"2022-07-08 23:51:10.949993: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n", | ||
"2022-07-08 23:51:10.950042: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"--------------------------------------------------------------------------------\n", | ||
"Andrej Karpathy, the chief of the criminal investigation department, said during a news conference, \"We still have a lot of\n", | ||
"--------------------------------------------------------------------------------\n", | ||
"Andrej Karpathy, the man whom most of America believes is the architect of the current financial crisis. He runs the National Council\n", | ||
"--------------------------------------------------------------------------------\n", | ||
"Andrej Karpathy, the head of the Department for Regional Reform of Bulgaria and an MP in the centre-right GERB party\n", | ||
"--------------------------------------------------------------------------------\n", | ||
"Andrej Karpathy, the former head of the World Bank's IMF department, who worked closely with the IMF. The IMF had\n", | ||
"--------------------------------------------------------------------------------\n", | ||
"Andrej Karpathy, the vice president for innovation and research at Citi who oversaw the team's work to make sense of the\n", | ||
"--------------------------------------------------------------------------------\n", | ||
"Andrej Karpathy, the CEO of OOAK Research, said that the latest poll indicates that it won't take much to\n", | ||
"--------------------------------------------------------------------------------\n", | ||
"Andrej Karpathy, the former prime minister of Estonia was at the helm of a three-party coalition when parliament met earlier this\n", | ||
"--------------------------------------------------------------------------------\n", | ||
"Andrej Karpathy, the director of the Institute of Economic and Social Research, said if the rate of return is only 5 per\n", | ||
"--------------------------------------------------------------------------------\n", | ||
"Andrej Karpathy, the minister of commerce for Latvia's western neighbour: \"The deal means that our two countries have reached more\n", | ||
"--------------------------------------------------------------------------------\n", | ||
"Andrej Karpathy, the state's environmental protection commissioner. \"That's why we have to keep these systems in place.\"\n", | ||
"\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# draw 10 samples of 20 new tokens each, conditioned on the prompt\n", | ||
"generate(prompt='Andrej Karpathy, the', num_samples=10, steps=20)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3.10.4 64-bit", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.4" | ||
}, | ||
"orig_nbformat": 4, | ||
"vscode": { | ||
"interpreter": { | ||
"hash": "3ad933181bd8a04b432d3370b9dc3b0662ad032c4dfaa4e4f1596c548f763858" | ||
} | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file was deleted.
Oops, something went wrong.