Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix argument passthrough for sweep #266

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
Revert "change variable name"
This reverts commit f2cdbb1.
  • Loading branch information
derpyplops committed Jul 12, 2023
commit 025acb1d5bed83c51ccc1ffb6fea0fb45e418f8a
2 changes: 1 addition & 1 deletion elk/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __post_init__(self):
# Set our output directory before super().execute() does
if not self.out_dir:
root = elk_reporter_dir() / self.source
- self.out_dir = root / "transfer" / "+".join(self.extract.datasets)
+ self.out_dir = root / "transfer" / "+".join(self.data.datasets)

def execute(self, highlight_color: Color = "cyan"):
return super().execute(highlight_color, split_type="val")
Expand Down
8 changes: 4 additions & 4 deletions elk/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

@dataclass
class Run(ABC, Serializable):
- extract: Extract
+ data: Extract
out_dir: Path | None = None
"""Directory to save results to. If None, a directory will be created
automatically."""
Expand Down Expand Up @@ -67,14 +67,14 @@ def execute(
min_gpu_mem=self.min_gpu_mem,
split_type=split_type,
)
- for cfg in self.extract.explode()
+ for cfg in self.data.explode()
]

if self.out_dir is None:
# Save in a memorably-named directory inside of
# ELK_REPORTER_DIR/<model_name>/<dataset_name>
- ds_name = "+".join(self.extract.datasets)
- root = elk_reporter_dir() / self.extract.model / ds_name
+ ds_name = "+".join(self.data.datasets)
+ root = elk_reporter_dir() / self.data.model / ds_name

self.out_dir = memorably_named_dir(root)

Expand Down
4 changes: 2 additions & 2 deletions elk/training/ccs_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def reset_parameters(self):
theta = torch.randn(1, probe.in_features + 1, device=probe.weight.device)
theta /= theta.norm()
probe.weight.data = theta[:, :-1]
- probe.bias.extract = theta[:, -1]
+ probe.bias.data = theta[:, -1]

elif self.config.init == "default":
for layer in self.probe:
Expand Down Expand Up @@ -219,7 +219,7 @@ def fit(self, hiddens: Tensor) -> float:
if self.config.init == "pca":
diffs = torch.flatten(x_pos - x_neg, 0, 1)
_, __, V = torch.pca_lowrank(diffs, q=i + 1)
- self.probe[0].weight.extract = V[:, -1, None].T
+ self.probe[0].weight.data = V[:, -1, None].T

if self.config.optimizer == "lbfgs":
loss = self.train_loop_lbfgs(x_neg, x_pos)
Expand Down
2 changes: 1 addition & 1 deletion elk/training/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
self.linear = torch.nn.Linear(
input_dim, num_classes if num_classes > 2 else 1, device=device, dtype=dtype
)
- self.linear.bias.extract.zero_()
+ self.linear.bias.data.zero_()
self.linear.weight.data.zero_()

def forward(self, x: Tensor) -> Tensor:
Expand Down
6 changes: 2 additions & 4 deletions elk/training/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class Sweep:

# A bit of a hack to add all the command line arguments from Elicit
run_template: Elicit = Elicit(
- extract=Extract(
+ data=Extract(
model="<placeholder>",
datasets=("<placeholder>",),
)
Expand Down Expand Up @@ -132,9 +132,7 @@ def execute(self):
out_dir = sweep_dir / model / dataset_str

data = replace(
- self.run_template.extract,
- model=model,
- datasets=train_datasets,
+ self.run_template.data, model=model, datasets=train_datasets
)
run = replace(self.run_template, data=data, out_dir=out_dir)
if var_weight is not None and neg_cov_weight is not None:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_classifier_roughly_same_sklearn():
)
# check that the weights are roughly the same
sklearn_coef = torch.from_numpy(model.coef_)
- torch_coef = classifier.linear.weight.extract
+ torch_coef = classifier.linear.weight.data
torch.testing.assert_close(sklearn_coef, torch_coef, atol=1e-2, rtol=1e-2)

# check that on a new sample, the predictions are roughly the same
Expand Down
4 changes: 2 additions & 2 deletions tests/test_smoke_elicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path):
model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024**2
dataset_name = "imdb"
elicit = Elicit(
- extract=Extract(
+ data=Extract(
model=model_path,
datasets=(dataset_name,),
max_examples=(10, 10),
Expand Down Expand Up @@ -41,7 +41,7 @@ def test_smoke_elicit_run_tiny_gpt2_eigen(tmp_path: Path):
model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024**2
dataset_name = "imdb"
elicit = Elicit(
- extract=Extract(
+ data=Extract(
model=model_path,
datasets=(dataset_name,),
max_examples=(10, 10),
Expand Down
6 changes: 3 additions & 3 deletions tests/test_smoke_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def setup_elicit(
Returns the elicit run configuration.
"""
elicit = Elicit(
- extract=Extract(
+ data=Extract(
model=model_path,
datasets=(dataset_name,),
max_examples=(10, 10),
Expand Down Expand Up @@ -54,7 +54,7 @@ def eval_run(elicit: Elicit, transfer_datasets: tuple[str, ...] = ()) -> float:
Returns a reference time (in seconds) for file modification checking.
"""
tmp_path = elicit.out_dir
- extract = elicit.extract
+ extract = elicit.data
assert tmp_path is not None

# record elicit modification time as reference.
Expand All @@ -64,7 +64,7 @@ def eval_run(elicit: Elicit, transfer_datasets: tuple[str, ...] = ()) -> float:
# update datasets to a different dataset
extract.datasets = transfer_datasets

- eval = Eval(extract=extract, source=tmp_path)
+ eval = Eval(data=extract, source=tmp_path)
eval.execute()
return start_time_sec

Expand Down