From 533401354a9b455e79550d09fd6be0c5f46ec1de Mon Sep 17 00:00:00 2001
From: Adam Scherlis
Date: Tue, 4 Jun 2024 08:05:50 +0000
Subject: [PATCH] act shape bug etc

---
 w2s/sft.py       | 7 ++++++-
 w2s/sft_utils.py | 8 ++++++--
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/w2s/sft.py b/w2s/sft.py
index fc93d98..397bd20 100644
--- a/w2s/sft.py
+++ b/w2s/sft.py
@@ -114,6 +114,8 @@ def train(
 
     model, tokenizer = init_model_and_tokenizer(model_cfg)
 
+    print(f"{get_gpu_mem_used() * 100:.2f}% of all GPU memory in use after model init")
+
     def process(examples):
         out = tokenizer(examples["txt"], truncation=True)
         return out
@@ -149,7 +151,8 @@ def compute_metrics(eval_pred):
         probe = PROBES[cfg["probe_name"]](cfg["probe"])
         probe.fit(acts, torch.tensor(ds_dict["train"]["labels"]))
         for name, ds in ds_dict.items():
-            acts = torch.load(acts_dir / f"{name}.pt", map_location=model.device)
+            all_acts = torch.load(acts_dir / f"{name}.pt", map_location=model.device)
+            acts = all_acts[:, cfg["num_hidden_layers"] // 2]
             preds = probe.predict(acts)
             agree_metrics = compute_metrics_torch(preds, torch.tensor(ds["labels"]))
             gt_metrics = compute_metrics_torch(preds, torch.tensor(ds["gt_labels"]))
@@ -177,6 +180,8 @@ def compute_metrics(eval_pred):
             f"Results already exist at {results_path}. Skipping training and evaluation."
         )
         return
+    else:
+        print(f"No results found at {results_path}. Training model.")
 
     if transfer and cfg["loss_name"] == "window" and cfg["loss"].radius == "midweak":
         confs = torch.abs(torch.tensor(ds_dict["train"]["labels"]) - 0.5)
diff --git a/w2s/sft_utils.py b/w2s/sft_utils.py
index 7bf260c..1fd5163 100644
--- a/w2s/sft_utils.py
+++ b/w2s/sft_utils.py
@@ -50,12 +50,16 @@ def gather_hiddens(model: torch.nn.Module, dataset: Dataset):
     D = assert_type(int, cfg.hidden_size)
     L = assert_type(int, cfg.num_hidden_layers)
 
-    buffer = torch.empty(L, len(dataset), D, device=model.device, dtype=model.dtype)
+    buffer = torch.empty(len(dataset), L, D, device=model.device, dtype=model.dtype)
+    print(f"Allocated buffer of shape {buffer.shape}")
     for i, ex in enumerate(tqdm(dataset)):
         ex = assert_type(dict, ex)
 
         out = model(ex["input_ids"][None], output_hidden_states=True)
-        buffer[i] = torch.stack(out.hidden_states)[:, 0, -1]  # Final token
+        act = torch.stack(out.hidden_states)[:, 0, -1]  # Final token
+        if act.shape != (L, D):
+            raise ValueError(f"Unexpected shape {act.shape} for hidden states on example {i}")
+        buffer[i] = act
 
     return buffer
 
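
Note (not part of the commit): a minimal sketch of how the reordered
(num_examples, num_layers, hidden_size) buffer written by gather_hiddens is
indexed downstream, mirroring the all_acts[:, cfg["num_hidden_layers"] // 2]
line added in w2s/sft.py. The sizes and the torch.randn stand-in for the saved
activations are hypothetical.

    import torch

    # Hypothetical sizes; in the repo these come from the model config and the dataset.
    num_examples, num_layers, hidden_size = 4, 24, 32
    # Stand-in for torch.load(acts_dir / f"{name}.pt") in sft.py.
    all_acts = torch.randn(num_examples, num_layers, hidden_size)

    # Take the middle layer's final-token activation for every example,
    # which is what the probe is fit and evaluated on.
    acts = all_acts[:, num_layers // 2]
    assert acts.shape == (num_examples, hidden_size)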