
Commit

update the usage of generate.py with a cli (#141)
PeiqinSun committed Apr 30, 2023
1 parent de8e7a8 commit aca32f6
Showing 2 changed files with 77 additions and 56 deletions.
2 changes: 1 addition & 1 deletion large_language_models/alpaca-qlora/README.md
@@ -29,7 +29,7 @@
- `python3 finetune_pp.py decapoda-research/llama-65b-hf /path/to/llama65b-pack8 --chunks 16 --pp_checkpoint except_last --micro_batch_size 32`

#### Inference
- `python3 generate.py`
- `python3 generate.py --load-qlora --llama-config /path/to/llama/config.json --qllama-checkpoint /path/to/quant-backbone-pack8 --qlora-dir /path/to/save/adapter --port 7860`

### training LLaMA-7b on single 2080ti
- the data of gpu-memory from nvidia-smi
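Note that, per `build_model` in the generate.py diff below, `--load-qlora` is optional: a hypothetical minimal launch that omits it falls back to the 8-bit `decapoda-research/llama-7b-hf` backbone with the stock `tloen/alpaca-lora-7b` adapter.
- `python3 generate.py --port 7860`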
131 changes: 76 additions & 55 deletions large_language_models/alpaca-qlora/generate.py
@@ -1,4 +1,5 @@
import os
from functools import partial
import torch
from peft import PeftModel
import transformers
@@ -12,37 +13,30 @@
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig

tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")

QUANT = True
PORT = 7860
CHECKPOINT_PATH = None
assert (
    CHECKPOINT_PATH
), "set the checkpoint_path as the path to your output-dir where save your adapter_model.bin before running"

if QUANT:
    model_cachedir = "./caches/llama-7b/"
    config = transformers.AutoConfig.from_pretrained(
        os.path.join(model_cachedir, "config.json")
    )
    model = load_qllama(
        config, os.path.join(model_cachedir, "llama-7b_4w_pack8.pth.tar")
    )
    model = PeftQModel.from_pretrained(
        model, CHECKPOINT_PATH, torch_dtype=torch.float16
    )

else:
    model = LlamaForCausalLM.from_pretrained(
        "decapoda-research/llama-7b-hf",
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    model = PeftModel.from_pretrained(
        model, "tloen/alpaca-lora-7b", torch_dtype=torch.float16
    )
def build_model(args):
    if args.load_qlora:
        config = transformers.AutoConfig.from_pretrained(args.llama_config)
        model = load_qllama(config, args.qllama_checkpoint)
        model = PeftQModel.from_pretrained(
            model, args.qlora_dir, torch_dtype=torch.float16
        )
        if torch.cuda.is_available():
            model.cuda()
    else:
        model = LlamaForCausalLM.from_pretrained(
            "decapoda-research/llama-7b-hf",
            load_in_8bit=True,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        model = PeftModel.from_pretrained(
            model,
            "tloen/alpaca-lora-7b",
            torch_dtype=torch.float16,
            device_map={"": 0},
        )
    return model


def generate_prompt(instruction, input=None):
@@ -65,11 +59,8 @@ def generate_prompt(instruction, input=None):
### Response:"""


model.cuda()
model.eval()


def evaluate(
    tokenizer,
    instruction,
    input=None,
    temperature=0.1,
@@ -101,24 +92,54 @@ def evaluate(
    return output.split("### Response:")[1].strip()


gr.Interface(
    fn=evaluate,
    inputs=[
        gr.components.Textbox(
            lines=2, label="Instruction", placeholder="Tell me about alpacas."
        ),
        gr.components.Textbox(lines=2, label="Input", placeholder="none"),
        gr.components.Slider(minimum=0, maximum=1, value=0.1, label="Temperature"),
        gr.components.Slider(minimum=0, maximum=1, value=0.75, label="Top p"),
        gr.components.Slider(minimum=0, maximum=100, step=1, value=40, label="Top k"),
        gr.components.Slider(minimum=0, maximum=4, step=1, value=4, label="Beams"),
    ],
    outputs=[
        gr.inputs.Textbox(
            lines=5,
            label="Output",
        )
    ],
    title="🦙🌲 Alpaca-LoRA",
    description="Alpaca-LoRA is a 7B-parameter LLaMA model finetuned to follow instructions. It is trained on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset and makes use of the Huggingface LLaMA implementation. For more information, please visit [the project's website](https://github.com/tloen/alpaca-lora).",
).launch(server_name="0.0.0.0", server_port=PORT, share=True)
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--load-qlora",
        action="store_true",
        help="load the quantized LoRA (QLoRA) backbone instead of the default 8-bit Alpaca-LoRA model",
    )
    parser.add_argument(
        "--llama-config", type=str, help="path to the LLaMA config.json"
    )
    parser.add_argument(
        "--qllama-checkpoint", type=str, help="path to the quantized LLaMA backbone checkpoint"
    )
    parser.add_argument(
        "--qlora-dir",
        type=str,
        help="path to the output dir that contains adapter_config.json & adapter_model.bin",
    )
    parser.add_argument("--port", type=int, default=7860)
    args = parser.parse_args()

    # all llama models use the same tokenizer
    tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")

    model = build_model(args)
    model.eval()

    evaluate_w_tokenizer = partial(evaluate, tokenizer)
    gr.Interface(
        fn=evaluate_w_tokenizer,
        inputs=[
            gr.components.Textbox(
                lines=2, label="Instruction", placeholder="Tell me about alpacas."
            ),
            gr.components.Textbox(lines=2, label="Input", placeholder="none"),
            gr.components.Slider(minimum=0, maximum=1, value=0.1, label="Temperature"),
            gr.components.Slider(minimum=0, maximum=1, value=0.75, label="Top p"),
            gr.components.Slider(
                minimum=0, maximum=100, step=1, value=40, label="Top k"
            ),
            gr.components.Slider(minimum=0, maximum=4, step=1, value=4, label="Beams"),
        ],
        outputs=[
            gr.inputs.Textbox(
                lines=5,
                label="Output",
            )
        ],
        title="🦙🌲 Alpaca-LoRA",
        description="Alpaca-LoRA is a 7B-parameter LLaMA model finetuned to follow instructions. It is trained on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset and makes use of the Huggingface LLaMA implementation. For more information, please visit [the project's website](https://github.com/tloen/alpaca-lora).",
    ).launch(server_name="0.0.0.0", server_port=args.port, share=True)
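For reference, a minimal sketch of exercising the new entry points without the Gradio UI; it assumes the definitions from generate.py above are in scope and reuses the placeholder paths from the README (substitute real checkpoint locations):

from argparse import Namespace

# stand-in for parse_args(); the paths below are placeholders, not shipped files
args = Namespace(
    load_qlora=True,
    llama_config="/path/to/llama/config.json",
    qllama_checkpoint="/path/to/quant-backbone-pack8",
    qlora_dir="/path/to/save/adapter",
    port=7860,
)
tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
model = build_model(args)
model.eval()
print(evaluate(tokenizer, "Tell me about alpacas."))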
