
Commit

Fix naming issue
norabelrose committed Mar 22, 2023
1 parent 761c82d commit f29743b
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions elk/training/train.py
@@ -70,14 +70,14 @@ def train_reporter(
     cfg: RunConfig,
     dataset: DatasetDict,
     out_dir: Path,
-    layer: list[int],
+    layers: list[int],
     devices: list[str],
     world_size: int = 1,
 ):
     """Train a single reporter on a single layer, or a list of layers."""

     # Reproducibility
-    seed = cfg.net.seed + layer[0]
+    seed = cfg.net.seed + layers[0]
     np.random.seed(seed)
     random.seed(seed)
     torch.manual_seed(seed)
@@ -96,17 +96,17 @@ def train_reporter(
     train_labels = cast(Tensor, train["label"])
     val_labels = cast(Tensor, val["label"])

-    # concatenate hidden states across layers if multiple layers are inputted
+    # Concatenate hidden states across layers if multiple layers are requested
     train_hiddens = torch.cat(
-        [cast(Tensor, train[f"hidden_{lay}"]) for lay in layer], dim=-1
+        [assert_type(Tensor, train[f"hidden_{layer}"]) for layer in layers], dim=-1
     )
     val_hiddens = torch.cat(
-        [cast(Tensor, val[f"hidden_{lay}"]) for lay in layer], dim=-1
+        [assert_type(Tensor, val[f"hidden_{lay}"]) for lay in layers], dim=-1
     )

     train_h, val_h = normalize(
-        int16_to_float32(assert_type(Tensor, train_hiddens)),
-        int16_to_float32(assert_type(Tensor, val_hiddens)),
+        int16_to_float32(train_hiddens),
+        int16_to_float32(val_hiddens),
         method=cfg.normalization,
     )

@@ -119,7 +119,7 @@ def train_reporter(
     )
     if pseudo_auroc > 0.6:
         warnings.warn(
-            f"The pseudo-labels at layers {layer} are linearly separable with "
+            f"The pseudo-labels at layers {layers} are linearly separable with "
             f"an AUROC of {pseudo_auroc:.3f}. This may indicate that the "
             f"algorithm will not converge to a good solution."
         )
@@ -143,7 +143,7 @@ def train_reporter(

     lr_dir.mkdir(parents=True, exist_ok=True)
     reporter_dir.mkdir(parents=True, exist_ok=True)
-    layer_name = layer if isinstance(layer, int) else max(layer)
+    layer_name = max(layers)
     stats = [layer_name, pseudo_auroc, train_loss, *val_result]

     if not cfg.skip_baseline:
@@ -170,10 +170,10 @@ def train_reporter(
         lr_auroc = roc_auc_score(val_labels_aug, lr_preds)

         stats += [lr_auroc, lr_acc]
-        with open(lr_dir / f"layer_{layer}.pt", "wb") as file:
+        with open(lr_dir / f"layer_{layers}.pt", "wb") as file:
             pickle.dump(lr_model, file)

-    with open(reporter_dir / f"layer_{layer}.pt", "wb") as file:
+    with open(reporter_dir / f"layer_{layers}.pt", "wb") as file:
         torch.save(reporter, file)

     return stats
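For context, the renamed layers argument feeds the torch.cat calls in the diff: the hidden states for each requested layer are concatenated along the last (feature) dimension before normalization and probe training. Below is a minimal standalone sketch of that pattern; the layer indices, tensor shapes, and the train dict are made up for illustration and are not taken from the repo.

import torch

# Hypothetical setup: 3 layers, 100 examples, 2 answer variants, hidden size 512.
layers = [4, 5, 6]
hidden_dim = 512
train = {f"hidden_{layer}": torch.randn(100, 2, hidden_dim) for layer in layers}

# Same pattern as the diff: stack per-layer hidden states along the feature axis.
train_hiddens = torch.cat(
    [train[f"hidden_{layer}"] for layer in layers], dim=-1
)
print(train_hiddens.shape)  # torch.Size([100, 2, 1536]); last dim is hidden_dim * len(layers)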
