-
Notifications
You must be signed in to change notification settings - Fork 2
/
build_coqa_value.py
90 lines (77 loc) · 2.87 KB
/
build_coqa_value.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import argparse
import csv
import json
import os
import pickle as pkl
import sys
from glob import glob
import numpy as np
import pandas as pd
from tqdm import tqdm
from .Dialects import (
AfricanAmericanVernacular,
AppalachianDialect,
ChicanoDialect,
ColloquialSingaporeDialect,
IndianDialect,
MultiDialect,
)
def main():
    """Translate every CoQA JSON file in ``--input`` into the chosen dialect.

    Command-line options:
        --input            Directory that is searched for ``*.json`` files.
        --dialect          One of ``aave``, ``indian``, ``singapore``,
                           ``chicano``, ``appalachian`` or ``multi``
                           (case-insensitive).
        --lexical_mapping  Path to a pickle file with the lexical mapping;
                           if the path does not exist an empty mapping is used.
        --morphosyntax     Forwarded to the dialect constructor.
        --html             Parsed but not used by this script.

    For each datapoint under the top-level ``"data"`` key of an input file:
        * every question keeps its original text under ``"sae_input_text"``
          and has ``"input_text"`` replaced by the converted text;
        * every answer gains a ``"dialect_input_text"`` key;
        * the datapoint gains a ``"dialect_story"`` key.

    The result is written to ``data/<dialect_code>_CoQA/<input file name>``.

    Exits with status 1 (via ``sys.exit``) if ``--dialect`` is not recognised.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", default="data/CoQA")
    parser.add_argument("--dialect", default="indian")
    parser.add_argument(
        "--lexical_mapping",
        default="NONE",
        help="a pickle file containing the lexical mapping from sae to the dialect",
    )
    parser.add_argument(
        "--morphosyntax",
        action="store_true",
        help="set this flag to include morphosyntactic transformations",
    )
    # NOTE(review): --html is accepted for CLI compatibility but nothing in
    # this function reads args.html.
    parser.add_argument(
        "--html",
        action="store_true",
        help="set this flag to add a column for HTML modification highlighting and tagging",
    )
    args = parser.parse_args()

    # The default "NONE" (or any other non-existent path) leaves the mapping empty.
    mapping = {}
    if os.path.exists(args.lexical_mapping):
        # SECURITY: unpickling can execute arbitrary code; only pass mapping
        # files that come from a trusted source.
        with open(args.lexical_mapping, "rb") as infile:
            mapping = pkl.load(infile)

    dialect_classes = {
        "aave": AfricanAmericanVernacular,
        "indian": IndianDialect,
        "singapore": ColloquialSingaporeDialect,
        "chicano": ChicanoDialect,
        "appalachian": AppalachianDialect,
        "multi": MultiDialect,
    }
    dialect_cls = dialect_classes.get(args.dialect.lower())
    if dialect_cls is None:
        # The original formatted the undefined name `dialect`, which raised
        # NameError instead of printing this message.
        print("Dialect {} Unimplemented".format(args.dialect))
        sys.exit(1)  # non-zero: an unsupported dialect is an error

    D = dialect_cls(mapping, morphosyntax=args.morphosyntax)
    D.clear()

    # sorted() makes the processing order independent of directory listing order.
    for fn in sorted(glob(os.path.join(args.input, "*.json"))):
        bn = os.path.basename(fn)
        print(fn, bn)
        with open(fn, "r", encoding="utf-8") as infile:
            data = json.load(infile)

        for datapoint in tqdm(data["data"]):
            for q in datapoint["questions"]:
                # Questions are converted in place; the original is preserved.
                q["sae_input_text"] = q["input_text"]
                q["input_text"] = D.convert_sae_to_dialect(q["input_text"])
            for a in datapoint["answers"]:
                # Answers keep "input_text" untouched and get a parallel key.
                a["dialect_input_text"] = D.convert_sae_to_dialect(a["input_text"])
            datapoint["dialect_story"] = D.convert_sae_to_dialect(datapoint["story"])

        output = f"data/{D.dialect_code}_CoQA"
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(output, exist_ok=True)
        with open(os.path.join(output, bn), "w", encoding="utf-8") as outfile:
            json.dump(data, outfile)
# Run the conversion only when this file is executed as a script, not when it
# is imported.
# NOTE(review): the file uses a relative import (`from .Dialects import ...`),
# so it presumably has to be launched as a module inside its package
# (`python -m <package>.build_coqa_value`) — confirm against the repo layout.
if __name__ == "__main__":
    main()