import zlib
from enum import Enum

import numpy as np
import torch


def calculate_individual_perplexity(input_ids, model):
    """Perplexity is defined as the exponential of the model's cross-entropy loss."""
model.eval()
with torch.no_grad():
output = model(input_ids, labels=input_ids)
perplexity = torch.exp(output.loss)
del output, input_ids
return float(perplexity.cpu().detach().numpy())
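
# --- Usage sketch (illustrative; not part of the original file) ---
# A minimal example of scoring one sample, assuming a Hugging Face causal LM;
# the "gpt2" checkpoint below is an assumption, and any model that returns a
# loss when called with `labels` would work the same way:
#
#   from transformers import GPT2LMHeadModel, GPT2TokenizerFast
#
#   tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
#   model = GPT2LMHeadModel.from_pretrained("gpt2")
#   input_ids = tokenizer.encode("Hello, world!", return_tensors="pt")
#   print(calculate_individual_perplexity(input_ids, model))
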
def calculate_individual_lower(input_ids, model, tokenizer, device):
    # if input_ids carries a batch dimension, squeeze it away
    if len(input_ids.size()) != 1:
        input_ids = input_ids.squeeze()
    text = tokenizer.decode(input_ids.cpu().detach().numpy())
text = text.lower()
input_ids = tokenizer.encode(text, return_tensors="pt")
input_ids = input_ids.to(device)
return calculate_individual_perplexity(input_ids, model)
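
# --- Usage sketch (illustrative; not part of the original file) ---
# One membership signal compares the perplexity of the original sample with
# the perplexity of its lowercased version, since memorization tends to be
# casing-sensitive. The scoring ratio below is an assumption about how the
# two values might be combined, not code from this repository:
#
#   ppl = calculate_individual_perplexity(input_ids, model)
#   ppl_lower = calculate_individual_lower(input_ids, model, tokenizer, device)
#   score = np.log(ppl_lower) / np.log(ppl)  # assumed: higher -> more suspicious
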
def calculate_individual_window(input_ids, model, window_size=50):
"""
Sometimes a model is not confident when the sample
contains one memorized substring surrounded by a
block of non-memorized (and high perplexity) text.
To handle this, we use the minimum perplexity when
averaged over a sliding window of 50 tokens.
"""
model.eval()
    # if input_ids carries a batch dimension, squeeze it away
    if len(input_ids.size()) != 1:
        input_ids = input_ids.squeeze()
    # if the sequence is shorter than the window, fall back to plain perplexity
    if input_ids.size(0) < window_size:
return calculate_individual_perplexity(input_ids, model)
# make tensors for the sliding window
sliding_windows = input_ids.unfold(0, window_size, 1)
min_perplexity = np.inf
    # keep the lowest perplexity score across all sliding windows
with torch.no_grad():
for tensor in sliding_windows:
perplexity = calculate_individual_perplexity(tensor, model)
del tensor
min_perplexity = min(min_perplexity, perplexity)
del input_ids
return min_perplexity
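
# --- Illustration (not part of the original file) ---
# `Tensor.unfold(0, window_size, 1)` enumerates every contiguous window of
# `window_size` tokens with stride 1, so a sequence of length n yields
# n - window_size + 1 windows, each of which is scored independently above:
#
#   >>> torch.arange(5).unfold(0, 3, 1)
#   tensor([[0, 1, 2],
#           [1, 2, 3],
#           [2, 3, 4]])
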
def calculate_individual_zlib(input_ids, tokenizer):
"""
As a simple baseline method, we compute the zlib entropy of the text:
the number of bits of entropy when the sequence is compressed with zlib compression.
Although text compressors are simple, they can identify many of the
examples of trivial memorization and repeated patterns described above
(e.g., they are excellent at modeling repeated substrings).
"""
    # if input_ids carries a batch dimension, squeeze it away
    if len(input_ids.size()) != 1:
        input_ids = input_ids.squeeze()
    text = tokenizer.decode(input_ids.cpu().detach().numpy())
text = text.lower()
return float(len(zlib.compress(bytes(text, "utf-8"))))
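
# --- Usage sketch (illustrative; not part of the original file) ---
# Repetitive text compresses well, so its zlib entropy stays small relative
# to its length; ranking samples by model perplexity against zlib entropy
# (an assumption about downstream use) filters out trivially repeated strings:
#
#   text = "the quick brown fox " * 20
#   print(len(text), len(zlib.compress(text.encode("utf-8"))))
#   # 400 raw bytes compress down to a few dozen
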
class Summary(Enum):
NONE = 0
AVERAGE = 1
SUM = 2
COUNT = 3


class Metric(object):
    """Computes and stores the average across the given batches."""
def __init__(self, name, fmt=":6.3f", summary_type=Summary.AVERAGE):
self.name = name
self.fmt = fmt
self.summary_type = summary_type
self.reset()
def reset(self):
self.val = 0
self.collected = np.array([])
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
        self.collected = (
            np.concatenate((self.collected, self.val))
            if isinstance(self.val, (np.ndarray, list))
            else np.append(self.collected, self.val)
        )
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __str__(self):
fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
return fmtstr.format(**self.__dict__)
def summary(self):
fmtstr = ""
if self.summary_type is Summary.NONE:
fmtstr = ""
elif self.summary_type is Summary.AVERAGE:
fmtstr = "{name} {avg:.3f}"
elif self.summary_type is Summary.SUM:
fmtstr = "{name} {sum:.3f}"
elif self.summary_type is Summary.COUNT:
fmtstr = "{name} {count:.3f}"
else:
raise ValueError("invalid summary type %r" % self.summary_type)
return fmtstr.format(**self.__dict__)
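
# --- Usage sketch (illustrative; not part of the original file) ---
#
#   ppl_meter = Metric("perplexity", fmt=":6.3f")
#   for batch_ppl in (12.3, 8.7, 10.1):
#       ppl_meter.update(batch_ppl)
#   print(ppl_meter)            # perplexity 10.100 (10.367)
#   print(ppl_meter.summary())  # perplexity 10.367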