Skip to content

Commit

Permalink
Rename the `data` attribute to `extract`
Browse files Browse the repository at this point in the history
  • Loading branch information
derpyplops committed Jul 10, 2023
1 parent 179e6b7 commit f2cdbb1
Show file tree
Hide file tree
Showing 8 changed files with 22 additions and 18 deletions.
2 changes: 1 addition & 1 deletion elk/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __post_init__(self):
# Set our output directory before super().execute() does
if not self.out_dir:
root = elk_reporter_dir() / self.source
self.out_dir = root / "transfer" / "+".join(self.data.datasets)
self.out_dir = root / "transfer" / "+".join(self.extract.datasets)

def execute(self, highlight_color: Color = "cyan"):
return super().execute(highlight_color, split_type="val")
Expand Down
8 changes: 4 additions & 4 deletions elk/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

@dataclass
class Run(ABC, Serializable):
data: Extract
extract: Extract
out_dir: Path | None = None
"""Directory to save results to. If None, a directory will be created
automatically."""
Expand Down Expand Up @@ -67,14 +67,14 @@ def execute(
min_gpu_mem=self.min_gpu_mem,
split_type=split_type,
)
for cfg in self.data.explode()
for cfg in self.extract.explode()
]

if self.out_dir is None:
# Save in a memorably-named directory inside of
# ELK_REPORTER_DIR/<model_name>/<dataset_name>
ds_name = "+".join(self.data.datasets)
root = elk_reporter_dir() / self.data.model / ds_name
ds_name = "+".join(self.extract.datasets)
root = elk_reporter_dir() / self.extract.model / ds_name

self.out_dir = memorably_named_dir(root)

Expand Down
4 changes: 2 additions & 2 deletions elk/training/ccs_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def reset_parameters(self):
theta = torch.randn(1, probe.in_features + 1, device=probe.weight.device)
theta /= theta.norm()
probe.weight.data = theta[:, :-1]
probe.bias.data = theta[:, -1]
probe.bias.extract = theta[:, -1]

elif self.config.init == "default":
for layer in self.probe:
Expand Down Expand Up @@ -285,7 +285,7 @@ def fit(self, hiddens: Tensor) -> float:
if self.config.init == "pca":
diffs = torch.flatten(x_pos - x_neg, 0, 1)
_, __, V = torch.pca_lowrank(diffs, q=i + 1)
self.probe[0].weight.data = V[:, -1, None].T
self.probe[0].weight.extract = V[:, -1, None].T

if self.config.optimizer == "lbfgs":
loss = self.train_loop_lbfgs(x_neg, x_pos)
Expand Down
2 changes: 1 addition & 1 deletion elk/training/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
self.linear = torch.nn.Linear(
input_dim, num_classes if num_classes > 2 else 1, device=device, dtype=dtype
)
self.linear.bias.data.zero_()
self.linear.bias.extract.zero_()
self.linear.weight.data.zero_()

def forward(self, x: Tensor) -> Tensor:
Expand Down
12 changes: 8 additions & 4 deletions elk/training/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class Sweep:

# A bit of a hack to add all the command line arguments from Elicit
run_template: Elicit = Elicit(
data=Extract(
extract=Extract(
model="<placeholder>",
datasets=("<placeholder>",),
)
Expand Down Expand Up @@ -132,7 +132,9 @@ def execute(self):
out_dir = sweep_dir / model / dataset_str

data = replace(
self.run_template.data, model=model, datasets=train_datasets
self.run_template.extract,
model=model,
datasets=train_datasets,
)
run = replace(self.run_template, data=data, out_dir=out_dir)
if var_weight is not None and neg_cov_weight is not None:
Expand Down Expand Up @@ -164,8 +166,10 @@ def execute(self):
assert run.out_dir is not None
# TODO we should fix this so that this isn't needed
eval = Eval(
data=replace(
run.data, model=model, datasets=(eval_dataset,)
extract=replace(
run.extract,
model=model,
datasets=(eval_dataset,),
),
source=run.out_dir,
out_dir=run.out_dir / "transfer" / eval_dataset,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_classifier_roughly_same_sklearn():
)
# check that the weights are roughly the same
sklearn_coef = torch.from_numpy(model.coef_)
torch_coef = classifier.linear.weight.data
torch_coef = classifier.linear.weight.extract
torch.testing.assert_close(sklearn_coef, torch_coef, atol=1e-2, rtol=1e-2)

# check that on a new sample, the predictions are roughly the same
Expand Down
4 changes: 2 additions & 2 deletions tests/test_smoke_elicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path):
model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024**2
dataset_name = "imdb"
elicit = Elicit(
data=Extract(
extract=Extract(
model=model_path,
datasets=(dataset_name,),
max_examples=(10, 10),
Expand Down Expand Up @@ -41,7 +41,7 @@ def test_smoke_elicit_run_tiny_gpt2_eigen(tmp_path: Path):
model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024**2
dataset_name = "imdb"
elicit = Elicit(
data=Extract(
extract=Extract(
model=model_path,
datasets=(dataset_name,),
max_examples=(10, 10),
Expand Down
6 changes: 3 additions & 3 deletions tests/test_smoke_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def setup_elicit(
Returns the elicit run configuration.
"""
elicit = Elicit(
data=Extract(
extract=Extract(
model=model_path,
datasets=(dataset_name,),
max_examples=(10, 10),
Expand Down Expand Up @@ -54,7 +54,7 @@ def eval_run(elicit: Elicit, transfer_datasets: tuple[str, ...] = ()) -> float:
Returns a reference time (in seconds) for file modification checking.
"""
tmp_path = elicit.out_dir
extract = elicit.data
extract = elicit.extract
assert tmp_path is not None

# record elicit modification time as reference.
Expand All @@ -64,7 +64,7 @@ def eval_run(elicit: Elicit, transfer_datasets: tuple[str, ...] = ()) -> float:
# update datasets to a different dataset
extract.datasets = transfer_datasets

eval = Eval(data=extract, source=tmp_path)
eval = Eval(extract=extract, source=tmp_path)
eval.execute()
return start_time_sec

Expand Down

0 comments on commit f2cdbb1

Please sign in to comment.