Skip to content

Commit

Permalink
Merge pull request #2 from SALT-NLP/will/mt_eval
Browse files Browse the repository at this point in the history
Add CIs and Additional Dialects
  • Loading branch information
Helw150 committed Dec 15, 2022
2 parents 5712b93 + 5c2f45b commit 922617e
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions eval_mt.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
from datasets import load_dataset
from transformers.pipelines.pt_utils import KeyDataset
import numpy as np
from src.Dialects import (
AfricanAmericanVernacular,
IndianDialect,
ColloquialSingaporeDialect,
ChicanoDialect,
AppalachianDialect,
NigerianDialect,
BlackSouthAfricanDialect,
)
from sacrebleu.metrics import BLEU
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# Pipeline task name.
TASK = "translation"
# Model checkpoint identifier. CKPT is assigned more than once below; when
# these lines execute in order, the last assignment is the one in effect.
CKPT = "facebook/nllb-200-1.3B"
# CKPT = "facebook/nllb-200-distilled-600M"
CKPT = "facebook/nllb-200-distilled-1.3B"
# Source-language code (English, Latin script).
src_lang = "eng_Latn"
# Maps the short language keys iterated over later in the script to the
# language codes paired with src_lang above.
tgt_lang_dict = {"de": "deu_Latn", "ru": "rus_Cyrl", "zh": "zho_Hans", "gu": "guj_Gujr"}
# Device index; overwritten by the CKPT-dependent assignment further down.
device = 0

import evaluate
# Pick the device index from the checkpoint name: 3 when the name contains
# "1.3B", otherwise 1. The value is later used to build a "cuda:<n>" string.
# NOTE(review): indices are hard-coded; presumably they match the GPUs on the
# author's machine -- confirm before running elsewhere.
device = 3 if "1.3B" in CKPT else 1


def dialect_factory(dialect):
Expand Down Expand Up @@ -49,7 +52,6 @@ def translate(examples):
return translate


sacrebleu = evaluate.load("sacrebleu")
model = AutoModelForSeq2SeqLM.from_pretrained(CKPT).to("cuda:" + str(device))
tokenizer = AutoTokenizer.from_pretrained(CKPT)
for lang in ["de", "gu", "zh", "ru"]:
Expand All @@ -63,13 +65,16 @@ def translate(examples):
max_length=400,
device=device,
)
sacrebleu = BLEU(trg_lang=lang)
for dialect in [
None,
AfricanAmericanVernacular,
IndianDialect,
ColloquialSingaporeDialect,
ChicanoDialect,
AppalachianDialect,
NigerianDialect,
BlackSouthAfricanDialect,
]:
d_dataset = dataset.map(flatten_factory(lang))
if dialect:
Expand All @@ -79,8 +84,11 @@ def translate(examples):
else:
dialect_name = "Standard American"
d_dataset = d_dataset.map(translate_factory(pipe), batched=True)
results = sacrebleu.compute(
predictions=d_dataset["tgt_pred"], references=d_dataset["tgt"]
# Fixed-seed NumPy generator.
# NOTE(review): `rng` is not passed to corpus_score or referenced anywhere in
# the visible lines -- confirm it is used elsewhere, otherwise it has no
# effect on the bootstrap resampling below.
rng = np.random.default_rng(12345)
# Corpus-level score of the predicted translations against one reference
# stream (the references are wrapped in an outer list). n_bootstrap=1000
# presumably enables bootstrap resampling for confidence intervals -- verify
# against the sacrebleu documentation.
res = sacrebleu.corpus_score(
list(d_dataset["tgt_pred"]),
[list(d_dataset["tgt"])],
n_bootstrap=1000,
)
print(f"{dialect_name} en -> {lang}")
print(results)
print(res.format().encode("latin-1", "replace").decode("latin-1"))

0 comments on commit 922617e

Please sign in to comment.