Skip to content

Commit

Permalink
Amend self. references
Browse files Browse the repository at this point in the history
  • Loading branch information
mwmlo committed Aug 24, 2023
1 parent b2f21df commit 5c59cb1
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/visualization/analyse_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def load_tokens(self):
self.label2idx, self.idx2label, src2idx, idx2src = mapping

def train_probe(self):
self.probe = linear_probe.train_logistic_regression_probe(self.X, self.y)
self.probe = linear_probe.train_logistic_regression_probe(self.X, self.y, lambda_l1=0.00005, lambda_l2=0.00005)
# Evaluate probe metrics
scores = linear_probe.evaluate_probe(
self.probe, self.X, self.y, idx_to_class=self.idx2label
Expand All @@ -52,7 +52,7 @@ def identify_concept_neurons(self):
self.load_tokens()
if self.probe is None:
self.train_probe()
top_neurons, top_neurons_per_class = linear_probe.get_top_neurons(probe, 0.5, label2idx)
top_neurons, top_neurons_per_class = linear_probe.get_top_neurons(self.probe, 0.5, self.label2idx)
return top_neurons_per_class['SEM:named_entity:location']

def show_top_words(self, concept_neurons):
Expand Down

0 comments on commit 5c59cb1

Please sign in to comment.