Skip to content

Commit

Permalink
attention entropy and gini coefficient
Browse files Browse the repository at this point in the history
  • Loading branch information
jaydeepborkar committed Jun 27, 2023
1 parent c2ca2c4 commit b93c97b
Showing 1 changed file with 48 additions and 5 deletions.
53 changes: 48 additions & 5 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import os



class PileDataset(Dataset):
"""
A wrapper around the Pile-derived pandas dataframe. This allows us to use the
Expand Down Expand Up @@ -178,7 +179,7 @@ def run_model_inferences(split_name: str, run_id: str, dataset: str, features: l
pile_dataset = PileDataset(pile_sequences, tokenizer)
batch_size = get_batch_size(split_name)
data_loader = DataLoader(pile_dataset, batch_size=batch_size)

with torch.no_grad():
desc = f"Collecting {dataset} inference responses for {split_name}"
for batch in tqdm(data_loader, desc=desc):
Expand All @@ -198,9 +199,24 @@ def run_model_inferences(split_name: str, run_id: str, dataset: str, features: l
labels=tokenized_batch["input_ids"],
output_attentions=True,
)


save_inference_log(split_name, run_id, dataset, batch, labels, outputs, features)


def gini(array):
    """Calculate the Gini coefficient of a numpy array.

    Ref: https://github.com/oliviaguest/gini

    The input is flattened and copied to float64 before any arithmetic, so
    the caller's data is never modified. If any value is negative, the whole
    array is shifted so that its minimum becomes zero.

    Args:
        array: Array-like of real numbers (any shape).

    Returns:
        The Gini coefficient: 0 when every value is equal, approaching 1 as
        the total concentrates in a single element. An input that sums to
        zero after shifting (e.g. all zeros) yields 0.0 instead of NaN.

    Raises:
        ValueError: If ``array`` contains no elements.
    """
    # based on bottom eq: https://www.statsdirect.com/help/content/image/stat0206_wmf.gif
    # from: https://www.statsdirect.com/help/default.htm#nonparametric_methods/gini.htm
    values = np.asarray(array, dtype=np.float64).flatten()  # flatten() always copies
    n = values.shape[0]
    if n == 0:
        raise ValueError("gini() requires at least one value")

    lowest = np.amin(values)
    if lowest < 0:
        # The formula assumes non-negative values.
        values -= lowest

    total = np.sum(values)
    if total == 0:
        # Every value is zero, i.e. all equal; avoid the 0/0 division.
        return 0.0

    values = np.sort(values)
    index = np.arange(1, n + 1)
    return np.sum((2 * index - n - 1) * values) / (n * total)


def save_inference_log(
split_name: str, run_id: str, dataset: pd.DataFrame, batch: tuple, labels: torch.Tensor, outputs: CausalLMOutputWithPast, features: list
):
Expand All @@ -219,8 +235,11 @@ def save_inference_log(
perplexities = [calculate_perplexity(logits[i], labels[i]) for i in range(len(logits))] if "ppl" in features else None
inference_logs = []
batch_sequence_ids = batch[0]
e=1e-8

for index, id_tensor in enumerate(batch_sequence_ids):
total_entropy = []
total_gini = []
inference_log = {"index": id_tensor.detach().item()}
if "loss" in features:
inference_log["loss"] = outputs.loss.detach().item() / len(labels[index])
Expand All @@ -229,16 +248,40 @@ def save_inference_log(
inference_log["generation_perplexity"] = perplexities[index][1]
inference_log["sequence_perplexity"] = perplexities[index][2]
if "attn" in features:
for layer_index, attention_layer in enumerate(outputs.attentions):
sequence_attention = attention_layer[index].detach().tolist()
inference_log[f"attn_{layer_index}"] = sequence_attention

for layer_index, attention_layer in enumerate(outputs.attentions):
sequence_attention = attention_layer[index].detach()
head_e = []
gini_head = []

for head_index, head in enumerate(sequence_attention):
attention_head = head.detach().cpu().numpy()
inference_log[f"head{head_index+1}_layer{layer_index+1}"] = attention_head
attention_head += e #adding 'e' to attention weights that are 0 to avoid log zero error while calculating entropy. Entropy = - ∑(w * log(w))
gini_coefficient = gini(attention_head)
gini_head.append(gini_coefficient)
head_entropy = -np.sum(attention_head * np.log(attention_head))
head_e.append(head_entropy)
inference_log[f"gini_head{head_index+1}_layer{layer_index+1}"] = gini_coefficient
inference_log[f"entropy_head{head_index+1}_layer{layer_index+1}"] = head_entropy

avg_head = np.mean(head_e)
avg_head_gini = np.mean(gini_head)
total_entropy.append(avg_head)
total_gini.append(avg_head_gini)

average_entropy = np.mean(total_entropy)
average_gini = np.mean(total_gini)
inference_log[f"avg entropy"] = average_entropy
inference_log[f"avg gini"] = average_gini

inference_logs.append(inference_log)

file_name = split_name.replace(".", "_")
inference_logs_df = pd.DataFrame(inference_logs)
inference_logs_df.to_csv(f"datasets/{run_id}/{dataset}_{file_name}.csv", index=False, mode="a")

return inference_logs


def parse_cli_args():
parser = ArgumentParser()
Expand Down

0 comments on commit b93c97b

Please sign in to comment.