Commit 6791633
Merge pull request EleutherAI#117 from EleutherAI/pca-clean
PCA init MVP
AlexTMallen committed Mar 8, 2023
2 parents: bb6b44c + 0aab4ed
Showing 1 changed file with 10 additions and 3 deletions.
elk/training/reporter.py: 13 changes (10 additions, 3 deletions)
@@ -58,7 +58,7 @@ class ReporterConfig(Serializable):
     activation: Literal["gelu", "relu", "swish"] = "gelu"
     bias: bool = True
     hidden_size: Optional[int] = None
-    init: Literal["default", "spherical", "zero"] = "default"
+    init: Literal["default", "pca", "spherical", "zero"] = "default"
     loss: list[str] = field(default_factory=lambda: ["ccs_prompt_var"])
     loss_dict: dict[str, float] = field(default_factory=dict, init=False)
     num_layers: int = 1
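
For context, opting into the new initialization just means setting init="pca" on the config. A minimal sketch (the import path follows the file shown above; whether the config is built directly like this or parsed from a CLI is not shown in this diff):

from elk.training.reporter import ReporterConfig

# Hypothetical usage: select the new PCA initialization; every other
# field keeps the defaults visible in the hunk above.
cfg = ReporterConfig(init="pca")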
@@ -219,10 +219,11 @@ def reset_parameters(self):
             for layer in self.probe:
                 if isinstance(layer, nn.Linear):
                     layer.reset_parameters()
+
         elif self.init == "zero":
             for param in self.parameters():
                 param.data.zero_()
-        else:
+        elif self.init != "pca":
             raise ValueError(f"Unknown init: {self.init}")
 
     # TODO: These methods will do something fancier in the future
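
The catch-all else: becomes elif self.init != "pca", so "pca" now falls through reset_parameters without raising; the PCA-based weights are assigned later, inside fit (next hunk). A sketch of the full dispatch, reconstructed under the assumption that the branches above this hunk handle "spherical" and "default" (only the tail of the method is visible here):

def reset_parameters(self):
    if self.init == "spherical":
        ...  # assumed branch, not visible in this hunk
    elif self.init == "default":
        for layer in self.probe:
            if isinstance(layer, nn.Linear):
                layer.reset_parameters()
    elif self.init == "zero":
        for param in self.parameters():
            param.data.zero_()
    elif self.init != "pca":
        # "pca" is deliberately a no-op here: fit() overwrites the first
        # layer's weight with a principal component instead.
        raise ValueError(f"Unknown init: {self.init}")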
@@ -316,9 +317,15 @@ def fit(
         best_state: dict[str, torch.Tensor] = {}  # State dict of the best run
         x0, x1 = contrast_pair
 
-        for _ in range(cfg.num_tries):
+        for i in range(cfg.num_tries):
             self.reset_parameters()
 
+            # This is sort of inefficient but whatever
+            if self.init == "pca":
+                diffs = torch.flatten(x0 - x1, 0, 1)
+                _, __, V = torch.pca_lowrank(diffs, q=i + 1)
+                self.probe[0].weight.data = V[:, -1, None].T
+
             if cfg.optimizer == "lbfgs":
                 loss = self.train_loop_lbfgs(x0, x1, labels, cfg)
             elif cfg.optimizer == "adam":
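
To unpack the new block: flattening the first two dimensions of x0 - x1 stacks all contrast-pair activation differences into a 2-D matrix, torch.pca_lowrank(diffs, q=i + 1) returns (U, S, V) with V of shape (hidden, i + 1), and V[:, -1, None].T takes the last requested component as a (1, hidden) row. So try i initializes the probe with, approximately (pca_lowrank is a randomized method), the (i + 1)-th principal direction of the differences. A self-contained sketch with made-up shapes:

import torch

batch, n_prompts, hidden = 32, 8, 64  # made-up sizes for illustration
x0 = torch.randn(batch, n_prompts, hidden)
x1 = torch.randn(batch, n_prompts, hidden)

for i in range(3):  # stands in for the cfg.num_tries loop above
    diffs = torch.flatten(x0 - x1, 0, 1)         # (batch * n_prompts, hidden)
    _, _, V = torch.pca_lowrank(diffs, q=i + 1)  # V: (hidden, i + 1)
    weight = V[:, -1, None].T                    # (1, hidden), the (i + 1)-th component
    print(i, weight.shape)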
