Skip to content

Commit

Permalink
Major refactor of CCS class
Browse files Browse the repository at this point in the history
  • Loading branch information
norabelrose committed Feb 3, 2023
1 parent 8772c2d commit 71d122b
Show file tree
Hide file tree
Showing 15 changed files with 220 additions and 368 deletions.
8 changes: 4 additions & 4 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"--model", "deberta-v2-xxlarge-mnli",
"--dataset","imdb",
"--prefix", "confusion",
"--model_device", "cuda",
"--device", "cuda",
"--num_data", "1000"
]
},
Expand All @@ -30,7 +30,7 @@
"--model", "deberta-v2-xxlarge-mnli",
"--dataset","imdb",
"--prefix", "normal",
"--model_device", "cuda",
"--device", "cuda",
"--num_data", "1000"
]
},
Expand All @@ -45,7 +45,7 @@
"--model", "deberta-v2-xxlarge-mnli",
"--dataset","imdb",
"--prefix", "normal",
"--model_device", "cuda",
"--device", "cuda",
"--num_data", "1000"
]
},
Expand All @@ -60,7 +60,7 @@
"--model", "deberta-v2-xxlarge-mnli",
"--dataset","imdb",
"--prefix", "normal",
"--model_device", "cuda",
"--device", "cuda",
"--num_data", "1000"
]
}
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Furthermore:
1. To generate the hidden states for one model `mdl` and all datasets, `cd elk` and then run:

```bash
python generation_main.py --model deberta-v2-xxlarge-mnli --datasets imdb --prefix normal --model_device cuda --num_data 1000
python generation_main.py --model deberta-v2-xxlarge-mnli --datasets imdb --prefix normal --device cuda --num_data 1000
```

To test `deberta-v2-xxlarge-mnli` with the misleading prefix, and only the `imdb` and `amazon-polarity` datasets, while printing extra information, run:
Expand Down
36 changes: 16 additions & 20 deletions elk/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
import os
import pickle
import numpy as np
import pandas as pd

from elk.utils_evaluation.ccs import CCS
from elk.utils_evaluation.utils_evaluation import (
get_hidden_states,
get_permutation,
Expand All @@ -12,10 +8,14 @@
from elk.utils_evaluation.parser import get_args
from elk.utils_evaluation.utils_evaluation import save_df_to_csv
from pathlib import Path
import numpy as np
import pandas as pd
import pickle
import torch


def evaluate(args, logistic_regression_model, ccs_model):
os.makedirs(args.save_dir, exist_ok=True)
def evaluate(args, logistic_regression_model, ccs_model: CCS):
args.save_dir.mkdir(parents=True, exist_ok=True)

hidden_states = get_hidden_states(
hidden_states_directory=args.hidden_states_directory,
Expand All @@ -34,24 +34,22 @@ def evaluate(args, logistic_regression_model, ccs_model):
accuracies_lr = []
losses_lr = []
for prompt_idx in range(len(hidden_states)):
data, labels = split(
features, labels = split(
hidden_states=hidden_states,
permutation=permutation,
prompts=[prompt_idx],
split="test",
)

# evaluate classification model
print("evaluate classification model")
acc_lr = logistic_regression_model.score(data, labels)
print("Evaluating logistic regression model")
acc_lr = logistic_regression_model.score(features, labels)
accuracies_lr.append(acc_lr)
losses_lr.append(0) # TODO: get loss from lr somehow

# evaluate ccs model
print("evaluate ccs model")
half = data.shape[1] // 2
data = [data[:, :half], data[:, half:]]
acc_ccs, loss_ccs = ccs_model.score(data, labels, getloss=True)
print("Evaluating CCS model")
x0, x1 = torch.from_numpy(features).to(args.device).chunk(2, dim=1)
labels = torch.from_numpy(labels).to(args.device)
acc_ccs, loss_ccs = ccs_model.score((x0, x1), labels)
accuracies_ccs.append(acc_ccs)
losses_ccs.append(loss_ccs)

Expand Down Expand Up @@ -89,16 +87,14 @@ def evaluate(args, logistic_regression_model, ccs_model):
stats_df = append_stats(
stats_df, args, "lr", avg_accuracy_lr, avg_accuracy_std_lr, avg_loss_lr
)
save_df_to_csv(args, stats_df, args.prefix, "After finish")
save_df_to_csv(args, stats_df, args.prefix)


if __name__ == "__main__":
args = get_args(default_config_path=Path(__file__).parent / "default_config.json")

# load pickle from file
with open(args.trained_models_path / "logistic_regression_model.pkl", "rb") as file:
logistic_regression_model = pickle.load(file)
with open(args.trained_models_path / "ccs_model.pkl", "rb") as file:
ccs_model = pickle.load(file)

ccs_model = CCS.load(args.trained_models_path / "ccs_model.pt")
evaluate(args, logistic_regression_model, ccs_model)
2 changes: 1 addition & 1 deletion elk/generate.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/bash

# run in background
nohup python -m elk.generation_main --model deberta-v2-xxlarge-mnli --datasets imdb --prefix normal --model_device cuda --num_data 1000 &
nohup python -m elk.generation_main --model deberta-v2-xxlarge-mnli --datasets imdb --prefix normal --device cuda --num_data 1000 &
ps -ax | grep generation_main
4 changes: 1 addition & 3 deletions elk/generation_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@
"finish loading model to memory. Now start loading to accelerator (gpu or"
f" mps). parallelize = {args.parallelize is True}"
)
model = put_model_on_device(
model, parallelize=args.parallelize, device=args.model_device
)
model = put_model_on_device(model, parallelize=args.parallelize, device=args.device)

print(
f"loading tokenizer for: model name = {args.model} at cache_dir ="
Expand Down
12 changes: 2 additions & 10 deletions elk/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
import os
import random

import numpy as np
import pickle

from train import train
Expand All @@ -11,18 +7,14 @@

if __name__ == "__main__":
args = get_args(default_config_path=Path(__file__).parent / "default_config.json")
os.makedirs(args.trained_models_path, exist_ok=True)

random.seed(args.seed)
np.random.seed(args.seed)
args.trained_models_path.mkdir(parents=True, exist_ok=True)

logistic_regression_model, ccs_model = train(args)

# save models
# TODO: use better filename for the pkls, so they don't get overwritten
with open(args.trained_models_path / "logistic_regression_model.pkl", "wb") as file:
pickle.dump(logistic_regression_model, file)
with open(args.trained_models_path / "ccs_model.pkl", "wb") as file:
pickle.dump(ccs_model, file)

ccs_model.save(args.trained_models_path / "ccs_model.pt")
evaluate(args, logistic_regression_model, ccs_model)
57 changes: 32 additions & 25 deletions elk/train.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
import pickle

from sklearn.linear_model import LogisticRegression

from pathlib import Path
from elk.utils_evaluation.ccs import CCS
from elk.utils_evaluation.utils_evaluation import (
get_hidden_states,
get_permutation,
split,
)
from elk.utils_evaluation.parser import get_args
from pathlib import Path
from sklearn.linear_model import LogisticRegression
import numpy as np
import pickle
import random
import torch


def train(args):
# Reproducibility
np.random.seed(args.seed)
random.seed(args.seed)
torch.manual_seed(args.seed)

# Extract hidden states from the model
hidden_states = get_hidden_states(
hidden_states_directory=args.hidden_states_directory,
model_name=args.model,
Expand All @@ -24,35 +31,35 @@ def train(args):
num_data=args.num_data,
)

# Set the random seed for the permutation
permutation = get_permutation(hidden_states)
data, labels = split(
# `features` has shape [batch_size, hidden_size * 2].
# The first half of the features comes from the first sentence,
# and the second half comes from the second sentence.
features, labels = split(
hidden_states,
permutation,
get_permutation(hidden_states),
prompts=range(len(hidden_states)),
split="train",
)
assert len(data.shape) == 2
assert len(features.shape) == 2

print("train classification model")
# TODO: Once we implement cross-validation for CCS, we should benchmark it against
# LogisticRegressionCV here.
print("Fitting logistic regression model...")
logistic_regression_model = LogisticRegression(max_iter=10000, n_jobs=1, C=0.1)
logistic_regression_model.fit(data, labels)
print("done training classification model")
logistic_regression_model.fit(features, labels)
print("Done.")

print("train ccs model")
half = data.shape[1] // 2
data = [data[:, :half], data[:, half:]]
d = data[0].shape[1]
print("Training CCS model...")
x0, x1 = torch.from_numpy(features).to(args.device).chunk(2, dim=1)

ccs_model = CCS(hidden_size=d)
ccs_model = CCS(in_features=features.shape[1] // 2, device=args.device)
ccs_model.fit(
data=data,
label=labels,
data=(x0, x1),
optimizer=args.optimizer,
verbose=True,
weight_decay=args.weight_decay,
)
print("done training ccs model")
print("Done.")

return logistic_regression_model, ccs_model

Expand All @@ -64,9 +71,9 @@ def train(args):
logistic_regression_model, ccs_model = train(args)

# save models
# TODO: use better filename for the pkls, so they don't get overwritten
Path(args.trained_models_path).mkdir(parents=True, exist_ok=True)
# TODO: use better filenames for the pkls, so they don't get overwritten
args.trained_models_path.mkdir(parents=True, exist_ok=True)
with open(args.trained_models_path / "logistic_regression_model.pkl", "wb") as file:
pickle.dump(logistic_regression_model, file)
with open(args.trained_models_path / "ccs_model.pkl", "wb") as file:
pickle.dump(ccs_model, file)

ccs_model.save(args.trained_models_path / "ccs_model.pt")
Loading

0 comments on commit 71d122b

Please sign in to comment.