act shape bug etc
AdamScherlis committed Jun 4, 2024
1 parent 6adbe82 commit 5334013
Showing 2 changed files with 12 additions and 3 deletions.
w2s/sft.py (7 changes: 6 additions & 1 deletion)
@@ -114,6 +114,8 @@ def train(
 
     model, tokenizer = init_model_and_tokenizer(model_cfg)
 
+    print(f"{get_gpu_mem_used() * 100:.2f}% of all GPU memory in use after model init")
+
     def process(examples):
         out = tokenizer(examples["txt"], truncation=True)
         return out
@@ -149,7 +151,8 @@ def compute_metrics(eval_pred):
     probe = PROBES[cfg["probe_name"]](cfg["probe"])
     probe.fit(acts, torch.tensor(ds_dict["train"]["labels"]))
     for name, ds in ds_dict.items():
-        acts = torch.load(acts_dir / f"{name}.pt", map_location=model.device)
+        all_acts = torch.load(acts_dir / f"{name}.pt", map_location=model.device)
+        acts = all_acts[:, cfg["num_hidden_layers"] // 2]
         preds = probe.predict(acts)
         agree_metrics = compute_metrics_torch(preds, torch.tensor(ds["labels"]))
         gt_metrics = compute_metrics_torch(preds, torch.tensor(ds["gt_labels"]))
@@ -177,6 +180,8 @@ def compute_metrics(eval_pred):
             f"Results already exist at {results_path}. Skipping training and evaluation."
         )
         return
+    else:
+        print(f"No results found at {results_path}. Training model.")
 
     if transfer and cfg["loss_name"] == "window" and cfg["loss"].radius == "midweak":
         confs = torch.abs(torch.tensor(ds_dict["train"]["labels"]) - 0.5)
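For context on the second hunk above: activations dumped by gather_hiddens are now laid out as (num_examples, num_layers, hidden_size), so the probe reads a single (middle) layer before fitting and predicting. A minimal sketch of that indexing, not part of the commit, with made-up sizes and a random tensor standing in for the saved activations:

import torch

# Hypothetical sizes standing in for len(dataset), cfg.num_hidden_layers, cfg.hidden_size
num_examples, num_layers, hidden_size = 8, 24, 64

# Stand-in for torch.load(acts_dir / f"{name}.pt"): shape (num_examples, num_layers, hidden_size)
all_acts = torch.randn(num_examples, num_layers, hidden_size)

# Middle-layer slice, mirroring `all_acts[:, cfg["num_hidden_layers"] // 2]` above
acts = all_acts[:, num_layers // 2]
assert acts.shape == (num_examples, hidden_size)  # what the probe expects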
w2s/sft_utils.py (8 changes: 6 additions & 2 deletions)
@@ -50,12 +50,16 @@ def gather_hiddens(model: torch.nn.Module, dataset: Dataset):
     D = assert_type(int, cfg.hidden_size)
     L = assert_type(int, cfg.num_hidden_layers)
 
-    buffer = torch.empty(L, len(dataset), D, device=model.device, dtype=model.dtype)
+    buffer = torch.empty(len(dataset), L, D, device=model.device, dtype=model.dtype)
+    print(f"Allocated buffer of shape {buffer.shape}")
     for i, ex in enumerate(tqdm(dataset)):
         ex = assert_type(dict, ex)
 
         out = model(ex["input_ids"][None], output_hidden_states=True)
-        buffer[i] = torch.stack(out.hidden_states)[:, 0, -1] # Final token
+        act = torch.stack(out.hidden_states)[:, 0, -1] # Final token
+        if act.shape != (L, D):
+            raise ValueError(f"Unexpected shape {act.shape} for hidden states on example {i}")
+        buffer[i] = act
 
     return buffer
 
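A standalone sketch of the corrected buffer layout and shape check, with assumed toy sizes and a random tensor replacing the model's stacked hidden states (for a real transformers model, out.hidden_states includes the embedding layer, i.e. num_hidden_layers + 1 entries, so the stack must be arranged to yield exactly (L, D) per example):

import torch

N, L, D = 4, 24, 64  # hypothetical: examples, hidden layers, hidden size

# Example-major buffer, matching torch.empty(len(dataset), L, D, ...)
buffer = torch.empty(N, L, D)

for i in range(N):
    # Stand-in for torch.stack(out.hidden_states)[:, 0, -1] in gather_hiddens
    act = torch.randn(L, D)
    if act.shape != (L, D):
        raise ValueError(f"Unexpected shape {act.shape} for hidden states on example {i}")
    buffer[i] = act  # dim 0 of buffer is the example index, so row-wise assignment works

assert buffer.shape == (N, L, D)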
