Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix argument passthrough for sweep #266

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
change variable name
  • Loading branch information
derpyplops committed Jul 12, 2023
commit aa1dc8861e935dffaf047846d4abc831e3769b86
2 changes: 1 addition & 1 deletion elk/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __post_init__(self):
# Set our output directory before super().execute() does
if not self.out_dir:
root = elk_reporter_dir() / self.source
self.out_dir = root / "transfer" / "+".join(self.data.datasets)
self.out_dir = root / "transfer" / "+".join(self.extract.datasets)

def execute(self, highlight_color: Color = "cyan"):
return super().execute(highlight_color, split_type="val")
Expand Down
8 changes: 4 additions & 4 deletions elk/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

@dataclass
class Run(ABC, Serializable):
data: Extract
extract: Extract
out_dir: Path | None = None
"""Directory to save results to. If None, a directory will be created
automatically."""
Expand Down Expand Up @@ -67,14 +67,14 @@ def execute(
min_gpu_mem=self.min_gpu_mem,
split_type=split_type,
)
for cfg in self.data.explode()
for cfg in self.extract.explode()
]

if self.out_dir is None:
# Save in a memorably-named directory inside of
# ELK_REPORTER_DIR/<model_name>/<dataset_name>
ds_name = "+".join(self.data.datasets)
root = elk_reporter_dir() / self.data.model / ds_name
ds_name = "+".join(self.extract.datasets)
root = elk_reporter_dir() / self.extract.model / ds_name

self.out_dir = memorably_named_dir(root)

Expand Down
4 changes: 2 additions & 2 deletions elk/training/ccs_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def reset_parameters(self):
theta = torch.randn(1, probe.in_features + 1, device=probe.weight.device)
theta /= theta.norm()
probe.weight.data = theta[:, :-1]
probe.bias.data = theta[:, -1]
probe.bias.extract = theta[:, -1]

elif self.config.init == "default":
for layer in self.probe:
Expand Down Expand Up @@ -219,7 +219,7 @@ def fit(self, hiddens: Tensor) -> float:
if self.config.init == "pca":
diffs = torch.flatten(x_pos - x_neg, 0, 1)
_, __, V = torch.pca_lowrank(diffs, q=i + 1)
self.probe[0].weight.data = V[:, -1, None].T
self.probe[0].weight.extract = V[:, -1, None].T

if self.config.optimizer == "lbfgs":
loss = self.train_loop_lbfgs(x_neg, x_pos)
Expand Down
2 changes: 1 addition & 1 deletion elk/training/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
self.linear = torch.nn.Linear(
input_dim, num_classes if num_classes > 2 else 1, device=device, dtype=dtype
)
self.linear.bias.data.zero_()
self.linear.bias.extract.zero_()
self.linear.weight.data.zero_()

def forward(self, x: Tensor) -> Tensor:
Expand Down
12 changes: 8 additions & 4 deletions elk/training/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class Sweep:

# A bit of a hack to add all the command line arguments from Elicit
run_template: Elicit = Elicit(
data=Extract(
extract=Extract(
model="<placeholder>",
datasets=("<placeholder>",),
)
Expand Down Expand Up @@ -132,7 +132,9 @@ def execute(self):
out_dir = sweep_dir / model / dataset_str

data = replace(
self.run_template.data, model=model, datasets=train_datasets
self.run_template.extract,
model=model,
datasets=train_datasets,
)
run = replace(self.run_template, data=data, out_dir=out_dir)
if var_weight is not None and neg_cov_weight is not None:
Expand Down Expand Up @@ -164,8 +166,10 @@ def execute(self):
assert run.out_dir is not None
# TODO we should fix this so that this isn't needed
eval = Eval(
data=replace(
run.data, model=model, datasets=(eval_dataset,)
extract=replace(
run.extract,
model=model,
datasets=(eval_dataset,),
),
source=run.out_dir,
out_dir=run.out_dir / "transfer" / eval_dataset,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_classifier_roughly_same_sklearn():
)
# check that the weights are roughly the same
sklearn_coef = torch.from_numpy(model.coef_)
torch_coef = classifier.linear.weight.data
torch_coef = classifier.linear.weight.extract
torch.testing.assert_close(sklearn_coef, torch_coef, atol=1e-2, rtol=1e-2)

# check that on a new sample, the predictions are roughly the same
Expand Down
4 changes: 2 additions & 2 deletions tests/test_smoke_elicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path):
model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024**2
dataset_name = "imdb"
elicit = Elicit(
data=Extract(
extract=Extract(
model=model_path,
datasets=(dataset_name,),
max_examples=(10, 10),
Expand Down Expand Up @@ -41,7 +41,7 @@ def test_smoke_elicit_run_tiny_gpt2_eigen(tmp_path: Path):
model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024**2
dataset_name = "imdb"
elicit = Elicit(
data=Extract(
extract=Extract(
model=model_path,
datasets=(dataset_name,),
max_examples=(10, 10),
Expand Down
6 changes: 3 additions & 3 deletions tests/test_smoke_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def setup_elicit(
Returns the elicit run configuration.
"""
elicit = Elicit(
data=Extract(
extract=Extract(
model=model_path,
datasets=(dataset_name,),
max_examples=(10, 10),
Expand Down Expand Up @@ -54,7 +54,7 @@ def eval_run(elicit: Elicit, transfer_datasets: tuple[str, ...] = ()) -> float:
Returns a reference time (in seconds) for file modification checking.
"""
tmp_path = elicit.out_dir
extract = elicit.data
extract = elicit.extract
assert tmp_path is not None

# record elicit modification time as reference.
Expand All @@ -64,7 +64,7 @@ def eval_run(elicit: Elicit, transfer_datasets: tuple[str, ...] = ()) -> float:
# update datasets to a different dataset
extract.datasets = transfer_datasets

eval = Eval(data=extract, source=tmp_path)
eval = Eval(extract=extract, source=tmp_path)
eval.execute()
return start_time_sec

Expand Down