Skip to content

Commit

Permalink
Switch from arguments to option for the systems' translations.
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelLarkin committed May 2, 2022
1 parent 40cc1d3 commit caf853c
Showing 1 changed file with 9 additions and 11 deletions.
20 changes: 9 additions & 11 deletions comet/cli/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,10 @@
-h, --help Show this help message and exit.
-s SOURCES, --sources SOURCES
(required unless using -d, type: Path_fr)
-x SYSTEM_X, --system_x SYSTEM_X
(required, type: Path_fr)
-y SYSTEM_Y, --system_y SYSTEM_Y
(required, type: Path_fr)
-r REFERENCES, --references REFERENCES
(type: Path_fr, default: None)
-t TRANSLATIONS[, TRANSLATIONS], --translations TRANSLATIONS[, TRANSLATIONS]
(type: Path_fr, default: None)
-d SACREBLEU_TESTSET, --sacrebleu_dataset SACREBLEU_TESTSET
(optional, use in place of -s and -r, type: str
format TESTSET:LANGPAIR, e.g., wmt20:en-de)
Expand Down Expand Up @@ -244,10 +242,10 @@ def get_cfg() -> Namespace:
"""
Parse the CLI options and arguments.
"""
parser = ArgumentParser(description="Command for comparing multiple MT systems.")
parser = ArgumentParser(description="Command for comparing multiple MT systems' translations.")
parser.add_argument("-s", "--sources", type=Path_fr)
parser.add_argument("-r", "--references", type=Path_fr)
parser.add_argument("systems", nargs="*", type=Path_fr)
parser.add_argument("-t", "--translations", nargs="*", type=Path_fr)
parser.add_argument("-d", "--sacrebleu_dataset", type=str)
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--gpus", type=int, default=1)
Expand Down Expand Up @@ -380,13 +378,13 @@ def compare_command() -> None:
cfg = get_cfg()
seed_everything(cfg.seed_everything)

assert len(cfg.systems) > 1, "You must provide at least 2 systems"
assert len(cfg.translations) > 1, "You must provide at least 2 translation files"

with open(cfg.sources()) as fp:
sources = [line.strip() for line in fp.readlines()]

translations = []
for system in cfg.systems:
for system in cfg.translations:
with open(system, mode='r', encoding='UTF-8') as fp:
translations.append([line.strip() for line in fp.readlines()])

Expand All @@ -407,20 +405,20 @@ def compare_command() -> None:
num_splits=cfg.num_splits,
)

t_test_results = list(pairwise_t_test(sys_scores, cfg.systems))
t_test_results = list(pairwise_t_test(sys_scores, cfg.translations))
for data in t_test_results:
display_ttest_result(data)

info = {
"model": cfg.model,
"t_test": t_test_results,
"source": sources,
"systems": [
"translations": [
{
"name": name,
"mt": trans,
"scores": scores.tolist(),
} for name, trans, scores in zip(cfg.systems, translations, seg_scores)
} for name, trans, scores in zip(cfg.translations, translations, seg_scores)
],
}
if references is not None:
Expand Down

0 comments on commit caf853c

Please sign in to comment.