Commit c586d9b: update prompt

hisashi-ito committed Apr 8, 2023
1 parent 6591eeb
Showing 3 changed files with 34 additions and 19 deletions.
input.txt: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+明日は晴れなので、
+失礼しました、
+大学の試験があったのですが、
+週末の予定は、
+高度な車両管理業務を実現する次世代型テレマティクスサービスブランド「LINKEETH」を
+近年、世界的な潮流として企業に対する脱炭素および
+株式会社NTTドコモ(以下、ドコモ)は、高性能ながら持ちやすいサイズ感の
+AIによる色味の調整は従来の顔や背景だけでなく、
+NTTグループでは、就業人口の急速な減少や高齢化、耕作放棄地の増加など、日本の農業の様々な課題をICTを活用して
launch.sh: 3 additions & 3 deletions
@@ -1,12 +1,12 @@
 #! /bin/bash
 IMAGE="gpt-neox"
-CONTAINER="gpt-neox"
+CONTAINER="gpt-neox3"
 sudo docker run -tid \
      --privileged \
      --gpus all \
      -v /data:/data \
-     -v /mnt/localdisk:/mnt/localdisk \
-     --shm-size=2000gb \
+     -v /var/data:/var/data \
+     --shm-size=128gb \
      --network=host \
      --name ${CONTAINER} \
      ${IMAGE} /bin/bash
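
The relaunch flow under the new settings, as a minimal sketch (assuming the script is executable and the gpt-neox image has already been built locally):

    # start the detached container with the updated name and mounts,
    # then attach an interactive shell to it
    ./launch.sh
    sudo docker exec -it gpt-neox3 /bin/bash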
prompt.py: 22 additions & 16 deletions
@@ -1,24 +1,30 @@
 # coding=utf-8
 
+#
+# usage: prompt.py <model_path>
+#
 import sys
 import torch
 from transformers import GPTNeoXForCausalLM
 from megatron.tokenizer.tokenizer import HFTokenizer
 
-model = GPTNeoXForCausalLM.from_pretrained("hf_model/save/location")
-tokenizer = HFTokenizer("/data/gpt_neox/tokenizer/del_post_processor_tokenizer.json")
-
-prompt = sys.argv[1]
-
-input_ids = tokenizer.tokenize(prompt)
-input_ids = torch.tensor([input_ids])
+def main():
+    tokenizer = HFTokenizer("/data/gpt_neox/tokenizer/del_post_processor_tokenizer.json")
+    model_path = sys.argv[1]
+    model = GPTNeoXForCausalLM.from_pretrained(model_path)
+    with open("input.txt", mode="r", encoding="utf-8") as fin:
+        for line in fin:
+            text = line.rstrip("\n")
+            input_ids = tokenizer.tokenize(text)
+            input_ids = torch.tensor([input_ids])
+            gen_tokens = model.generate(
+                input_ids,
+                do_sample=True,
+                temperature=0.9,
+                max_length=256
+            )
+            hypo = tokenizer.detokenize(gen_tokens[0].tolist())
+            print(f"{model_path}\t{text}\t{hypo}")
 
-gen_tokens = model.generate(
-    input_ids,
-    do_sample=True,
-    temperature=0.95,
-    max_length=64
-)
-
-text = tokenizer.detokenize(gen_tokens[0].tolist())
-print(text)
+if __name__ == '__main__':
+    main()
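
A usage sketch for the rewritten script (the checkpoint directory below is a hypothetical path to a converted Hugging Face model; the script reads prompts from input.txt in the working directory and prints one tab-separated model_path, prompt, generation triple per input line):

    # /data/gpt_neox/hf_model is a placeholder checkpoint path
    python prompt.py /data/gpt_neox/hf_model > generations.tsv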
