Skip to content

Commit

Permalink
Corrected file name
Browse files Browse the repository at this point in the history
  • Loading branch information
root committed Oct 30, 2023
1 parent a6d16b2 commit d2402e8
Showing 1 changed file with 68 additions and 0 deletions.
68 changes: 68 additions & 0 deletions t-few/src/scripts/get_result_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from glob import glob
import json
from collections import defaultdict
from scipy.stats import iqr
from numpy import median, mean, std
import os
import argparse


def make_result_table(args):
def collect_exp_scores(exp_name_template, datasets):
print("=" * 80)
all_files = glob(
os.path.join(os.getenv("OUTPUT_PATH", default="exp_out"), exp_name_template, "dev_scores.json")
)
print(f"Find {len(all_files)} experiments fit into {exp_name_template}")

def read_last_eval(fname):
with open(fname) as f:
e = json.loads(f.readlines()[-1])
return e["AUC"]

acc_by_dataset = defaultdict(lambda: list())

def parse_expname(fname):
elements = fname.split("/")[-2].split("_")
return tuple(elements[:3] + ["_".join(elements[3:])])

for fname in all_files:
result = read_last_eval(fname)
model, dataset, seed, spec = parse_expname(fname)
acc_by_dataset[dataset].append(result)

def result_str(acc_list):
if len(acc_list) > 1:
return f"{mean(acc_list) * 100:.2f} ({std(acc_list) * 100:.2f})"
else:
return f"{acc_list[0] * 100:.2f}"

outputs = []
for dataset in datasets:
acc_list = acc_by_dataset[dataset]
outputs.append(result_str(acc_list))

print(", ".join([f"{dataset}: {value}" for dataset, value in zip(datasets, outputs)]))
return ",".join(outputs)

csv_lines = ["template," + (",".join(args.datasets))]
for exp_name_template in args.exp_name_templates:
csv_lines.append(f"{exp_name_template}," + collect_exp_scores(exp_name_template, args.datasets))

output_fname = os.path.join(os.getenv("OUTPUT_PATH", default="exp_out"), "summary.csv")
with open(output_fname, "w") as f:
for line in csv_lines:
f.write(line + "\n")
print(f"Save result to {output_fname}")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-e", "--exp_name_templates", default="t03b_*_finetune", required=True)
parser.add_argument(
"-d", "--datasets", default="copa,h-swag,storycloze,winogrande,wsc,wic,rte,cb,anli-r1,anli-r2,anli-r3"
)
args = parser.parse_args()
args.exp_name_templates = args.exp_name_templates.split(",")
args.datasets = args.datasets.split(",")
make_result_table(args)

0 comments on commit d2402e8

Please sign in to comment.