-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of https://github.com/tmu-nlp/100knock2022 into k…
…yotaro
- Loading branch information
Showing
49 changed files
with
21,709 additions
and
2,508 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import re | ||
import sentencepiece as spm | ||
|
||
spm.SentencePieceTrainer.Train('--input=./Downloads/kftt-data-1.0/data/orig/kyoto-train.ja --model_prefix=kyoto_ja --vocab_size=16000 --character_coverage=1.0') | ||
sp = spm.SentencePieceProcessor() | ||
sp.Load('kyoto_ja.model') | ||
|
||
for src, dst in [ | ||
('./Downloads/kftt-data-1.0/data/orig/kyoto-train.ja', 'train.sub.ja'), | ||
('./Downloads/kftt-data-1.0/data/orig/kyoto-dev.ja', 'dev.sub.ja'), | ||
('./Downloads/kftt-data-1.0/data/orig/kyoto-test.ja', 'test.sub.ja'), | ||
]: | ||
with open(src) as f, open(dst, 'w') as g: | ||
for x in f: | ||
x = x.strip() | ||
x = re.sub(r'\s+', ' ', x) | ||
x = sp.encode_as_pieces(x) | ||
x = ' '.join(x) | ||
print(x, file=g) | ||
|
||
%%bash | ||
subword-nmt learn-bpe -s 16000 < ./Downloads/kftt-data-1.0/data/orig/kyoto-train.en > kyoto_en.codes | ||
subword-nmt apply-bpe -c kyoto_en.codes < ./Downloads/kftt-data-1.0/data/orig/kyoto-train.en > train.sub.en | ||
subword-nmt apply-bpe -c kyoto_en.codes < ./Downloads/kftt-data-1.0/data/orig/kyoto-dev.en > dev.sub.en | ||
subword-nmt apply-bpe -c kyoto_en.codes < ./Downloads/kftt-data-1.0/data/orig/kyoto-test.en > test.sub.en |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
fairseq-train data91 \ | ||
--fp16 \ | ||
--tensorboard-logdir log96 \ | ||
--save-dir save96 \ | ||
--max-epoch 5 \ | ||
--arch transformer --share-decoder-input-output-embed \ | ||
--optimizer adam --clip-norm 1.0 \ | ||
--lr 1e-3 --lr-scheduler inverse_sqrt --warmup-updates 2000 \ | ||
--dropout 0.2 --weight-decay 0.0001 \ | ||
--update-freq 1 \ | ||
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ | ||
--max-tokens 8000 > 96.log |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
fairseq-train data91 \ | ||
--fp16 \ | ||
--save-dir save97_1 \ | ||
--max-epoch 10 \ | ||
--arch transformer --share-decoder-input-output-embed \ | ||
--optimizer adam --clip-norm 1.0 \ | ||
--lr 1e-3 --lr-scheduler inverse_sqrt --warmup-updates 2000 \ | ||
--dropout 0.1 --weight-decay 0.0001 \ | ||
--update-freq 1 \ | ||
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ | ||
--max-tokens 8000 > 97_1.log | ||
|
||
fairseq-train data91 \ | ||
--fp16 \ | ||
--save-dir save97_2 \ | ||
--max-epoch 10 \ | ||
--arch transformer --share-decoder-input-output-embed \ | ||
--optimizer adam --clip-norm 1.0 \ | ||
--lr 1e-3 --lr-scheduler inverse_sqrt --warmup-updates 2000 \ | ||
--dropout 0.3 --weight-decay 0.0001 \ | ||
--update-freq 1 \ | ||
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ | ||
--max-tokens 8000 > 97_2.log |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import tarfile | ||
|
||
with tarfile.open('./Downloads/en-ja.tar') as tar: | ||
for f in tar.getmembers(): | ||
if f.name.endswith('txt'): | ||
text = tar.extractfile(f).read().decode('utf-8') | ||
break | ||
|
||
data = text.splitlines() | ||
data = [x.split('\t') for x in data] | ||
data = [x for x in data if len(x) == 4] | ||
data = [[x[3], x[2]] for x in data] | ||
|
||
with open('./Downloads/jparacrawl.ja', 'w') as f, open('./Downloads/jparacrawl.en', 'w') as g: | ||
for j, e in data: | ||
print(j, file=f) | ||
print(e, file=g) | ||
|
||
with open('./Downloads/jparacrawl.ja') as f, open('./Downloads/train.jparacrawl.ja', 'w') as g: | ||
for x in f: | ||
x = x.strip() | ||
x = re.sub(r'\s+', ' ', x) | ||
x = sp.encode_as_pieces(x) | ||
x = ' '.join(x) | ||
print(x, file=g) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import subprocess | ||
import MeCab | ||
import time | ||
from bottle import route, run, template, request | ||
from datetime import datetime | ||
|
||
shell = 'CUDA_VISIBLE_DEVICES={$N} PYTHONIOENCODING=utf-8 fairseq-interactive data91 --path save91/checkpoint_best.pt --beam 1' | ||
|
||
@route('/translate') | ||
def output(): | ||
now = datetime.now() | ||
return template('knock99', text_inp='', text_res='') | ||
|
||
tagger = MeCab.Tagger('-Owakati') | ||
|
||
@route('/translate', method='POST') | ||
def translate(): | ||
proc = subprocess.Popen(shell, encoding='utf-8', stdin=subprocess.PIPE, stdout=subprocess.PIPE, shell=True) | ||
input_text = request.forms.input_text | ||
input_ = tagger.parse(input_text) | ||
proc.stdin.write(input_) | ||
proc.stdin.close() | ||
res = proc.stdout.readlines()[-2].strip().split('\t')[-1] | ||
|
||
return template('knock99', text_inp=input_text, text_res=res) | ||
|
||
run(host='localhost', port=8080, debug=True) |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
#beam search | ||
GPU=$1 | ||
for N in `seq 1 20` ; do | ||
CUDA_VISIBLE_DEVICES=$GPU fairseq-interactive \ | ||
--path checkpoints/kftt.ja-en/checkpoint_best.pt \ | ||
--beam $N data/data-bin/kftt.ja-en/ \ | ||
< data/kftt-data-1.0/data/tok/kyoto-dev.ja | grep '^H' | cut -f3 > out94/beam.$N.out | ||
done | ||
|
||
for N in `seq 1 20` ; do | ||
echo beam=$N >> out94/score94.out | ||
CUDA_VISIBLE_DEVICES=$GPU fairseq-score --sys out94/beam.$N.out --ref data/kftt-data-1.0/data/tok/kyoto-dev.en >> out94/score94.out | ||
done |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import sentencepiece as spm | ||
import re, os | ||
|
||
#日本語はsentencepiece | ||
spm.SentencePieceTrainer.Train("--input=./data/kftt-data-1.0/data/orig/kyoto-train.ja --model_prefix=kyoto_ja --vocab_size=16000 --character_coverage=0.9995") | ||
sp = spm.SentencePieceProcessor() | ||
sp.Load("kyoto_ja.model") | ||
|
||
for src, dst in [ | ||
("data/kftt-data-1.0/data/orig/kyoto-train.ja", "data/kftt-data-1.0/data/bpe/train.sub.ja"), | ||
("data/kftt-data-1.0/data/orig/kyoto-dev.ja", "data/kftt-data-1.0/data/bpe/dev.sub.ja"), | ||
("data/kftt-data-1.0/data/orig/kyoto-test.ja", "data/kftt-data-1.0/data/bpe/test.sub.ja"), | ||
]: | ||
with open(src) as f, open(dst, "w") as g: | ||
for x in f: | ||
x = x.strip() | ||
x = re.sub(r"\s+", " ", x)#全ての空白文字を置換 | ||
x = sp.encode_as_pieces(x)#返り値はリストっぽい | ||
x = " ".join(x)#空白で結合して文字列に | ||
print(x, file=g) | ||
|
||
#英語はsubword-nmt | ||
#os.system:unixコマンドがpythonで使える | ||
os.system("subword-nmt learn-bpe -s 16000 < data/kftt-data-1.0/data/orig/kyoto-train.en > kyoto_en.codes") | ||
for tar in ["train", "dev", "test"]: | ||
os.system(f"subword-nmt apply-bpe -c kyoto_en.codes < data/kftt-data-1.0/data/orig/kyoto-{tar}.en > data/kftt-data-1.0/data/bpe/{tar}.sub.en") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
#前処理 | ||
GPU=$1 \ | ||
CUDA_VISIBLE_DEVICES=$GPU fairseq-preprocess --source-lang ja --target-lang en \ | ||
--trainpref data/kftt-data-1.0/data/bpe/train.sub \ | ||
--validpref data/kftt-data-1.0/data/bpe/dev.sub \ | ||
--testpref data/kftt-data-1.0/data/bpe/test.sub \ | ||
--destdir data/data-bin/kftt-bpe.ja-en/ \ | ||
--thresholdsrc 5 \ | ||
--thresholdtgt 5 \ | ||
--workers 20 | ||
|
||
#学習 | ||
GPU1=$1 \ | ||
GPU2=$2 \ | ||
CUDA_VISIBLE_DEVICES=$GPU1,$GPU2 fairseq-train data/data-bin/kftt-bpe.ja-en \ | ||
--save-dir checkpoints/kftt-bpe.ja-en/ \ | ||
--arch transformer --share-decoder-input-output-embed \ | ||
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ | ||
--lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \ | ||
--dropout 0.3 --weight-decay 0.0001 \ | ||
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ | ||
--max-tokens 4096 \ | ||
--max-epoch 10 | ||
|
||
#翻訳 | ||
GPU=$1 | ||
CUDA_VISIBLE_DEVICES=$GPU fairseq-interactive data/data-bin/kftt-bpe.ja-en \ | ||
--path checkpoints/kftt-bpe.ja-en/checkpoint_best.pt \ | ||
--remove-bpe \ | ||
< data/kftt-data-1.0/data/bpe/test.sub.ja \ | ||
| grep '^H' | cut -f3 > out/knock95.out | ||
|
||
#評価 | ||
fairseq-score --sys out/knock95.out --ref data/kftt-data-1.0/data/tok/kyoto-test.en | ||
|
||
""" | ||
サブワード前 | ||
Namespace(ignore_case=False, order=4, ref='data/kftt-data-1.0/data/tok/kyoto-test.en', sacrebleu=False, sentence_bleu=False, sys='knock92.out') | ||
BLEU4 = 5.34, 34.7/9.3/3.7/1.8 (BP=0.781, ratio=0.802, syslen=21432, reflen=26734) | ||
サブワード後 | ||
Namespace(ignore_case=False, order=4, ref='data/kftt-data-1.0/data/tok/kyoto-test.en', sacrebleu=False, sentence_bleu=False, sys='out/knock95.out') | ||
BLEU4 = 7.25, 28.2/10.7/4.5/2.0 (BP=1.000, ratio=1.091, syslen=29154, reflen=26734) | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
#学習 | ||
GPU1=$1 \ | ||
GPU2=$2 \ | ||
CUDA_VISIBLE_DEVICES=$GPU1,$GPU2 fairseq-train data/data-bin/kftt.ja-en \ | ||
--save-dir checkpoints96/kftt.ja-en/ \ | ||
--tensorboard-logdir log96 \ | ||
--arch transformer --share-decoder-input-output-embed \ | ||
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ | ||
--lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \ | ||
--dropout 0.3 --weight-decay 0.0001 \ | ||
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ | ||
--max-tokens 4096 \ | ||
--max-epoch 10 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
#学習 | ||
GPU1=$1 | ||
GPU2=$2 | ||
for N in `seq 0.1 0.2 0.5`; do | ||
CUDA_VISIBLE_DEVICES=$GPU1,$GPU2 fairseq-train data/data-bin/kftt-bpe.ja-en \ | ||
--save-dir checkpoints/kftt-bpe.ja-en/dropout_$N \ | ||
--arch transformer --share-decoder-input-output-embed \ | ||
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ | ||
--lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \ | ||
--dropout $N --weight-decay 0.0001 \ | ||
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ | ||
--max-tokens 4096 \ | ||
--max-epoch 5 | ||
done | ||
|
||
#翻訳 | ||
GPU=$1 | ||
for N in `seq 1 2 5` ; do | ||
CUDA_VISIBLE_DEVICES=$GPU fairseq-interactive data/data-bin/kftt-bpe.ja-en \ | ||
--path checkpoints/kftt-bpe.ja-en/dropout_$N/checkpoint_best.pt \ | ||
< data/kftt-data-1.0/data/bpe/dev.sub.ja | grep '^H' | cut -f3 > out/out96/dropout_$N.out | ||
done | ||
|
||
#評価 | ||
GPU=$1 | ||
for N in `seq 1 2 5` ; do | ||
echo beam=$N >> out/out97/score97.out | ||
CUDA_VISIBLE_DEVICES=$GPU fairseq-score --sys out/out97/dropout_$N.out --ref data/kftt-data-1.0/data/tok/kyoto-dev.en >> out/out97/score97.out | ||
done |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import os | ||
|
||
with open('./jesc/split/train') as train_f: | ||
with open('./jesc/split/train.ja', 'w') as ja_f, open('./jesc/split/train.en', 'w') as en_f: | ||
for line in train_f: | ||
en, ja = line.split('\t') | ||
ja_f.write(ja) | ||
en_f.write(en+'\n') | ||
|
||
os.system("mkdir data-bin/jesc.ja-en") | ||
os.system( | ||
"fairseq-preprocess --source-lang ja --target-lang en \ | ||
--trainpref jesc/split/train \ | ||
--validpref ../kftt-data-1.0/data/tok/kyoto-dev \ | ||
--testpref ../kftt-data-1.0/data/tok/kyoto-test \ | ||
--destdir data-bin/jesc.ja-en/ \ | ||
--bpe subword_nmt \ | ||
--thresholdsrc 5 \ | ||
--thresholdtgt 5 \ | ||
--workers 20" | ||
) | ||
|
||
os.system("mkdir checkpoints/jesc.ja-en/") | ||
os.system( | ||
"fairseq-train data-bin/jesc.ja-en \ | ||
--save-dir checkpoints/jesc.ja-en/ \ | ||
--arch lstm --share-decoder-input-output-embed \ | ||
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ | ||
--lr 5e-3 --lr-scheduler inverse_sqrt --warmup-updates 4000 \ | ||
--dropout 0.3 --weight-decay 0.0001 \ | ||
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ | ||
--max-tokens 4096 \ | ||
--max-epoch 5 \ | ||
" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import streamlit as st | ||
import os | ||
import MeCab, ipadic | ||
|
||
#タイトルを決める | ||
st.title("knock99 Demo") | ||
|
||
#テキストのinput | ||
text = st.text_area("翻訳したいテキストを入力", max_chars=500) | ||
|
||
#ボタンを押すとモデルが翻訳開始 | ||
start_translation = st.button("翻訳") | ||
|
||
#入力を分かち書き | ||
wakati = MeCab.Tagger("-Owakati") | ||
result = wakati.parse(text) | ||
|
||
#分かち書き後の入力テキストをファイル出力 | ||
with open("input.txt", "w") as f: | ||
print(result, file=f) | ||
|
||
#翻訳してファイルに結果を出力 | ||
if start_translation: | ||
os.system( | ||
"fairseq-interactive data/data-bin/kftt-bpe.ja-en/ \ | ||
--path checkpoints/kftt-bpe.ja-en/checkpoint_best.pt \ | ||
< input.txt \ | ||
| grep '^H' | cut -f3 | sed -r 's/(@@ )|(@@ ?$)//g' > knock99.txt" | ||
) | ||
with open("knock99.txt", "r") as f:#いちいちファイル出力どうにかしたい | ||
output_list = [] | ||
for line in f: | ||
output_list.append(line) | ||
outputs = output_list[0] | ||
else: | ||
outputs = "" | ||
|
||
#翻訳結果を表示 | ||
st.text_area("翻訳結果", outputs) | ||
|
Oops, something went wrong.