-
Notifications
You must be signed in to change notification settings - Fork 2
/
plot_correlation.py
95 lines (92 loc) · 3.76 KB
/
plot_correlation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
import sys
import csv
import json
import pandas
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
CHECKPOINTS = {0, 3200, 6400, 20000, 40000, 80000, 160000, 320000, 420000, 520000, 640000, 760000, 880000, 1000000}
ticks = []
data = []
chkpt_header = "checkpoint (in update steps)"
summary_root = "summaries"
scenario = "ALL (161GB)"
# Load LKT summaries: one row per test directory, holding the metric at each
# checkpoint in CHECKPOINTS, ordered by increasing checkpoint.
probe_name = "LKT"
# os.listdir returns entries in arbitrary order; sort so the row order of
# `ticks`/`data` is reproducible across machines and filesystems.
for test in sorted(os.listdir(f"{summary_root}/{probe_name}")):
    test_dir = f"{summary_root}/{probe_name}/{test}"
    if not os.path.isdir(test_dir):
        continue
    tick = f"{probe_name}-{test}"
    ticks.append(tick)
    csv_path = f"{test_dir}/{scenario}.csv"
    table = pandas.read_csv(csv_path, index_col=False)
    table.sort_values(by=[chkpt_header], inplace=True)
    metrics = table['metric'][table[chkpt_header].isin(CHECKPOINTS)].to_list()
    # Each row must have exactly one value per checkpoint, so that position j
    # refers to the same checkpoint in every row of `data`.
    if len(metrics) != len(CHECKPOINTS):
        raise ValueError(
            f"{csv_path}: expected {len(CHECKPOINTS)} checkpoint rows, "
            f"found {len(metrics)}"
        )
    data.append(metrics)
# Load BLiMP summaries: one row per test directory, holding the metric at each
# checkpoint in CHECKPOINTS, ordered by increasing checkpoint.
probe_name = "BLiMP"
# os.listdir returns entries in arbitrary order; sort so the row order of
# `ticks`/`data` is reproducible across machines and filesystems.
for test in sorted(os.listdir(f"{summary_root}/{probe_name}")):
    test_dir = f"{summary_root}/{probe_name}/{test}"
    if not os.path.isdir(test_dir):
        continue
    tick = f"{probe_name}-{test}"
    ticks.append(tick)
    csv_path = f"{test_dir}/{scenario}.csv"
    table = pandas.read_csv(csv_path, index_col=False)
    table.sort_values(by=[chkpt_header], inplace=True)
    metrics = table['metric'][table[chkpt_header].isin(CHECKPOINTS)].to_list()
    # Each row must have exactly one value per checkpoint, so that position j
    # refers to the same checkpoint in every row of `data`.
    if len(metrics) != len(CHECKPOINTS):
        raise ValueError(
            f"{csv_path}: expected {len(CHECKPOINTS)} checkpoint rows, "
            f"found {len(metrics)}"
        )
    data.append(metrics)
# Load LAMA summaries: one row per (test, relation_type) pair, holding the
# metric at each checkpoint in CHECKPOINTS, ordered by increasing checkpoint.
# The "metric" cell is a string that is evaluated and then indexed by K
# (presumably a container keyed by the top-K cutoff -- confirm on-disk format).
probe_name = "LAMA"
K = 1
assert K in [1, 5, 10]
# os.listdir returns entries in arbitrary order; sort so the row order of
# `ticks`/`data` is reproducible across machines and filesystems.
for test in sorted(os.listdir(f"{summary_root}/{probe_name}")):
    test_dir = f"{summary_root}/{probe_name}/{test}"
    if not os.path.isdir(test_dir):
        continue
    csv_path = f"{test_dir}/{scenario}.csv"
    table = pandas.read_csv(csv_path, index_col=False)
    table.sort_values(by=[chkpt_header], inplace=True)
    # NOTE(review): eval() executes whatever text is stored in the CSV. If the
    # stored value is a plain literal (dict/list of numbers), ast.literal_eval
    # is a safer drop-in -- confirm the format before switching.
    table["metric"] = table["metric"].apply(lambda x: eval(x)[K])
    sub_table = table[table[chkpt_header].isin(CHECKPOINTS)]
    # groupby preserves the original (checkpoint-sorted) row order within
    # each group, so each relation's metrics stay in checkpoint order.
    for relation_name, relation_sub_table in sub_table.groupby("relation_type"):
        ticks.append(f"{probe_name}-{test}-{relation_name}")
        metrics = relation_sub_table['metric'].tolist()
        # Same alignment requirement as the other loaders: one value per
        # checkpoint, so position j means the same checkpoint in every row.
        if len(metrics) != len(CHECKPOINTS):
            raise ValueError(
                f"{csv_path} (relation {relation_name}): expected "
                f"{len(CHECKPOINTS)} checkpoint rows, found {len(metrics)}"
            )
        data.append(metrics)
# Load CAT summaries: one row per test directory, holding the metric at each
# checkpoint in CHECKPOINTS, ordered by increasing checkpoint.
probe_name = "CAT"
# os.listdir returns entries in arbitrary order; sort so the row order of
# `ticks`/`data` is reproducible across machines and filesystems.
for test in sorted(os.listdir(f"{summary_root}/{probe_name}")):
    test_dir = f"{summary_root}/{probe_name}/{test}"
    if not os.path.isdir(test_dir):
        continue
    tick = f"{probe_name}-{test}"
    ticks.append(tick)
    csv_path = f"{test_dir}/{scenario}.csv"
    table = pandas.read_csv(csv_path, index_col=False)
    table.sort_values(by=[chkpt_header], inplace=True)
    metrics = table['metric'][table[chkpt_header].isin(CHECKPOINTS)].to_list()
    # Each row must have exactly one value per checkpoint, so that position j
    # refers to the same checkpoint in every row of `data`.
    if len(metrics) != len(CHECKPOINTS):
        raise ValueError(
            f"{csv_path}: expected {len(CHECKPOINTS)} checkpoint rows, "
            f"found {len(metrics)}"
        )
    data.append(metrics)
# Load oLMpics summaries: one row per test directory, holding the metric at
# each checkpoint in CHECKPOINTS, ordered by increasing checkpoint.
probe_name = "oLMpics"
# os.listdir returns entries in arbitrary order; sort so the row order of
# `ticks`/`data` is reproducible across machines and filesystems.
for test in sorted(os.listdir(f"{summary_root}/{probe_name}")):
    test_dir = f"{summary_root}/{probe_name}/{test}"
    if not os.path.isdir(test_dir):
        continue
    tick = f"{probe_name}-{test}"
    ticks.append(tick)
    csv_path = f"{test_dir}/{scenario}.csv"
    table = pandas.read_csv(csv_path, index_col=False)
    table.sort_values(by=[chkpt_header], inplace=True)
    metrics = table['metric'][table[chkpt_header].isin(CHECKPOINTS)].to_list()
    # Each row must have exactly one value per checkpoint, so that position j
    # refers to the same checkpoint in every row of `data`.
    if len(metrics) != len(CHECKPOINTS):
        raise ValueError(
            f"{csv_path}: expected {len(CHECKPOINTS)} checkpoint rows, "
            f"found {len(metrics)}"
        )
    data.append(metrics)
# Load fine-tuning summaries for a fixed list of tasks: one row per task that
# has a summary directory, holding the metric at each checkpoint in
# CHECKPOINTS, ordered by increasing checkpoint.
# Set probe_name explicitly so this section does not inherit the value left
# by the previous loader (which would make it look in the wrong directory).
# NOTE(review): assumes the summaries live under <summary_root>/Finetune/,
# matching this section's name like the other loaders -- confirm.
probe_name = "Finetune"
for test in ["CoLA", "MRPC", "SST-2", "WNLI", "WSC"]:
    test_dir = f"{summary_root}/{probe_name}/{test}"
    # Tasks without a summary directory are skipped.
    if not os.path.isdir(test_dir):
        continue
    tick = f"{probe_name}-{test}"
    ticks.append(tick)
    csv_path = f"{test_dir}/{scenario}.csv"
    table = pandas.read_csv(csv_path, index_col=False)
    table.sort_values(by=[chkpt_header], inplace=True)
    metrics = table['metric'][table[chkpt_header].isin(CHECKPOINTS)].to_list()
    # Each row must have exactly one value per checkpoint, so that position j
    # refers to the same checkpoint in every row of `data`.
    if len(metrics) != len(CHECKPOINTS):
        raise ValueError(
            f"{csv_path}: expected {len(CHECKPOINTS)} checkpoint rows, "
            f"found {len(metrics)}"
        )
    data.append(metrics)