Commit 73adac2 (parent 748e920): Added scripts for language agency classifier
Showing 22 changed files with 1,040 additions and 459 deletions.
# Data Generation, Data Preprocessing, and Training Scripts for the Language Agency Classifier
Refer to the following instructions to generate the training data and train the Language Agency Classifier.
## Data Generation and Preprocessing
You may refer to the following steps to generate and preprocess the language agency classification dataset. Alternatively, access our generated and preprocessed dataset stored in `/agency_classifier/agency_dataset/`.
1. Generate the raw language agency classification dataset by prompting ChatGPT to rephrase each original biography into an agentic version and a communal version. To run generation, first add your OpenAI organization and API key in `/agency_classifier/agency_generation_util.py`, then use the following command:
```
cd agency_classifier
sh run_generate.sh
```
The generated raw dataset will be stored in `/agency_classifier/agency_bios/BIOS_sampled_preprocessed.csv`.
2. Split the generated raw file into train, test, and validation sets:
```
# Make sure you are still in the agency_classifier directory
sh run_split.sh
```
The processed datasets will be stored in `/agency_classifier/agency_dataset/`. A quick sanity check on both outputs is sketched below.
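As a sanity check, you can inspect the generated file and the label balance of each split. This is a minimal sketch, not part of the provided scripts; it assumes the column names written by `generate_dataset.py` and the 0 = communal, 1 = agentic label convention from `split_biasbios.py`:
```
import pandas as pd
from collections import Counter

# Columns written by generate_dataset.py
gen_df = pd.read_csv("agency_bios/BIOS_sampled_preprocessed.csv")
print(gen_df[["raw_bio", "agentic_gen", "communal_gen"]].head())

# Label convention from split_biasbios.py: 0 = communal, 1 = agentic
for split in ["train", "val", "test"]:
    split_df = pd.read_csv("agency_dataset/{}.csv".format(split))
    print(split, Counter(split_df["label"]))
```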
## Training the Language Agency Classifier
You may refer to the following command to train the language agency classifier using the generated dataset. Alternatively, access our trained classifier checkpoint in Google Drive and place it under the `/agency_classifier/checkpoints/` directory.
To train the language agency classifier, run:
```
# Make sure you are still in the agency_classifier directory
sh run_train.sh
```
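After training, the classifier can be loaded for inference. Below is a minimal sketch, not part of the provided scripts: the `checkpoint-best` directory name is a placeholder for whichever checkpoint the Trainer actually saved under `./checkpoints/`, and the labels follow the 0 = communal, 1 = agentic convention from `split_biasbios.py`:
```
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# "checkpoint-best" is a placeholder; substitute the checkpoint folder
# actually saved by the Trainer under ./checkpoints/
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("./checkpoints/checkpoint-best")
model.eval()

inputs = tokenizer(
    "She led the team to a record-breaking quarter.",
    return_tensors="pt", truncation=True,
)
with torch.no_grad():
    pred = model(**inputs).logits.argmax(-1).item()
print("agentic" if pred == 1 else "communal")
```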
agency_classifier/agency_bios/BIOS_sampled_preprocessed.csv (101 additions; large diff not rendered)
agency_classifier/agency_generation_util.py
import random
import openai
from ratelimiter import RateLimiter
from retrying import retry

# Prompt templates for generating the agency classification dataset.
# Stored as a list so random.sample works (sampling from a set is
# unsupported in Python 3.11+).
AGENCY_DATASET_GEN_PROMPTS = [
    'You will rephrase a biography two times to demonstrate agentic and communal language traits respectively. "agentic" is defined as more achievement-oriented, and "communal" is defined as more social or service-oriented. The paragraph is: "{}"'
]

# Uncomment this part and fill in your OpenAI organization and API key to query ChatGPT's API
# openai.organization = $YOUR_ORGANIZATION$
# openai.api_key = $YOUR_API_KEY$

# Retry on transient failures, and rate-limit to at most 20 calls per
# 60 seconds to avoid exceeding the ChatGPT API rate limit
@retry(stop_max_attempt_number=10)
@RateLimiter(max_calls=20, period=60)
def generate_response_fn(utt):
    # Pick a prompt template and fill its "{}" placeholder with the biography
    prompt = random.sample(AGENCY_DATASET_GEN_PROMPTS, 1)[0].format(utt)
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo", messages=[{"role": "user", "content": prompt}]
    )
    return response["choices"][0]["message"]["content"].strip()
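
# Usage sketch (illustrative, not part of the original module): the biography
# below is a made-up example, and valid OpenAI credentials must be set above
# for the API call to succeed.
if __name__ == "__main__":
    print(generate_response_fn("Jane Doe is a physician who volunteers at local clinics."))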
agency_classifier/finetune_bert_bias.py
import os
import torch
import pandas as pd
import datasets
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
)
from argparse import ArgumentParser

# Disable wandb
os.environ["WANDB_DISABLED"] = "true"

# Collate tokenized examples into padded tensor batches
def data_collator(data):
    return {
        "input_ids": torch.stack([torch.tensor(f["input_ids"]) for f in data]),
        "attention_mask": torch.stack(
            [torch.tensor(f["attention_mask"]) for f in data]
        ),
        "labels": torch.tensor([f["label"] for f in data]),
    }

# Compute accuracy, precision, recall, and F1 on the argmax predictions
def compute_metrics(eval_preds):
    labels = eval_preds.label_ids
    preds = eval_preds.predictions.argmax(-1)
    acc = acc_metric.compute(predictions=preds, references=labels)["accuracy"]
    precision = precision_metric.compute(predictions=preds, references=labels)["precision"]
    recall = recall_metric.compute(predictions=preds, references=labels)["recall"]
    f1 = f1_metric.compute(predictions=preds, references=labels)["f1"]
    return {"accuracy": acc, "precision": precision, "recall": recall, "f1": f1}

if __name__ == "__main__":
    # Configuration (numeric arguments are typed so command-line overrides parse correctly)
    parser = ArgumentParser()
    parser.add_argument('-d', '--dataset_path', default='./agency_dataset/', required=False)
    parser.add_argument('-m', '--model_type', default="bert-base-uncased", required=False)
    parser.add_argument('-cp', '--checkpoint_path', default='./checkpoints', required=False)
    parser.add_argument('-lr', '--learning_rate', type=float, default=2e-5, required=False)
    parser.add_argument('-e', '--num_epochs', type=int, default=10, required=False)
    parser.add_argument('-tb', '--train_bsz', type=int, default=8, required=False)
    parser.add_argument('-eb', '--eval_bsz', type=int, default=16, required=False)
    parser.add_argument('-wd', '--weight_decay', type=float, default=0.01, required=False)
    args = parser.parse_args()

    # Load and shuffle the train, validation, and test splits
    train_df = pd.read_csv(os.path.join(args.dataset_path, "train.csv"))
    train_df = train_df.sample(frac=1).reset_index(drop=True)
    print('\nTrain set length:', len(train_df))
    train_dataset = datasets.Dataset.from_pandas(train_df)

    val_df = pd.read_csv(os.path.join(args.dataset_path, "val.csv"))
    val_df = val_df.sample(frac=1).reset_index(drop=True)
    val_dataset = datasets.Dataset.from_pandas(val_df)

    test_df = pd.read_csv(os.path.join(args.dataset_path, "test.csv"))
    test_df = test_df.sample(frac=1).reset_index(drop=True)
    test_dataset = datasets.Dataset.from_pandas(test_df)

    acc_metric = datasets.load_metric("accuracy")
    precision_metric = datasets.load_metric("precision")
    recall_metric = datasets.load_metric("recall")
    f1_metric = datasets.load_metric("f1")

    # Load the BERT tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_type)

    # Tokenize the dataset
    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    tokenized_train_dataset = train_dataset.map(
        tokenize_function, batched=True, batch_size=len(train_dataset)
    )
    tokenized_val_dataset = val_dataset.map(
        tokenize_function, batched=True, batch_size=len(val_dataset)
    )
    tokenized_test_dataset = test_dataset.map(
        tokenize_function, batched=True, batch_size=len(test_dataset)
    )

    # Load the BERT model with a binary classification head
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_type, num_labels=2
    )

    # Define the training arguments
    training_args = TrainingArguments(
        report_to="none",  # turn off wandb and other loggers
        output_dir=args.checkpoint_path,
        evaluation_strategy="epoch",
        learning_rate=args.learning_rate,
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.train_bsz,
        per_device_eval_batch_size=args.eval_bsz,
        weight_decay=args.weight_decay,
        save_strategy="epoch",
        logging_strategy="epoch",
        load_best_model_at_end=True,
        fp16=True,  # enable mixed-precision training
        gradient_accumulation_steps=2,  # accumulate gradients over every 2 batches
    )

    # Define the trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train_dataset,
        eval_dataset=tokenized_val_dataset,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    # Fine-tune the model, then evaluate on the held-out test set
    trainer.train()
    print('\n\n\n Testing -------------------------------- \n')
    print(trainer.evaluate(eval_dataset=tokenized_test_dataset))
agency_classifier/generate_dataset.py
import os
import pandas as pd
from agency_generation_util import generate_response_fn
from tqdm import tqdm
from argparse import ArgumentParser

if __name__ == "__main__":
    # Configuration
    parser = ArgumentParser()
    parser.add_argument('-d', '--dataset_path', default='./agency_bios/', required=False)
    args = parser.parse_args()

    dataset = os.path.join(args.dataset_path, 'BIOS_sampled.csv')
    df = pd.read_csv(dataset)

    # Query ChatGPT for agentic and communal rephrasings of each biography
    df['raw_gen_data'] = None
    for i in tqdm(range(len(df))):
        df.loc[i, 'raw_gen_data'] = generate_response_fn(df.loc[i, 'raw_bio'])

    # Parse the raw generations, which are expected to look roughly like
    # {"agentic": <agentic rephrasing>, "communal": <communal rephrasing>}
    df["agentic_gen"] = None
    df["communal_gen"] = None
    for i in tqdm(range(len(df))):
        d = df.loc[i, 'raw_gen_data']
        data = d.split('{"agentic":')
        data = list(filter(lambda a: a != '', data))[-1]
        data = data.split('"communal":')
        data = list(filter(lambda a: a not in ['', ',', '.', ';', '?', '!', '/'], data))
        df.loc[i, "agentic_gen"], df.loc[i, "communal_gen"] = data[0], data[1]

    df.to_csv(os.path.join(args.dataset_path, 'BIOS_sampled_preprocessed.csv'))
agency_classifier/run_generate.sh
python generate_dataset.py -d ./agency_bios
agency_classifier/run_split.sh
python split_biasbios.py -if ./agency_bios/BIOS_sampled_preprocessed.csv -of ./agency_dataset/
agency_classifier/run_train.sh
export CUDA_VISIBLE_DEVICES=0 && python3 finetune_bert_bias.py -d ./agency_dataset/ -cp ./checkpoints
BIOS sampling script (file path not rendered in the diff)
import pickle
import random
import pandas as pd

file_name = "../dataset/BIOS.pkl"
output_path = "../dataset/BIOS_sampled_2.csv"
MAX_COUNT = 50  # biographies to keep per gender

with open(file_name, "rb") as f:
    # Load the contents of the file
    contents = pickle.load(f)
random.shuffle(contents)

# Iterate over the contents line by line, keeping at most
# MAX_COUNT male and MAX_COUNT female biographies
saved_data = {
    "name": [],
    "gender": [],
    "raw_title": [],
    "raw_bio": [],
    "title": [],
    "bio": [],
}
male_count = 0
female_count = 0

for line in contents:
    if line["gender"] == "M":
        if male_count == MAX_COUNT:
            continue
        male_count += 1
    else:
        if female_count == MAX_COUNT:
            continue
        female_count += 1

    # Join first/middle/last name, skipping an empty middle name
    saved_data["name"].append(
        line["name"][0] + " " + line["name"][2]
        if line["name"][1] == ""
        else " ".join(line["name"])
    )
    saved_data["gender"].append(line["gender"])
    saved_data["raw_title"].append(line["raw_title"])
    saved_data["title"].append(line["title"])
    saved_data["raw_bio"].append(line["raw"])
    saved_data["bio"].append(line["bio"])

data = pd.DataFrame(saved_data)
data.to_csv(output_path, index=False)
agency_classifier/split_biasbios.py
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from collections import Counter
from argparse import ArgumentParser

def load_from_bios(dataset_path, is_balanced):
    print("Loading data from {}".format(dataset_path))
    temp_df = pd.read_csv(dataset_path)

    # Labels: communal = 0, agentic = 1
    text, label = [], []

    for i in range(len(temp_df)):
        # Adding raw bios can skew the dataset distribution (mostly agentic),
        # so they are included only when a balanced split is not required
        if not is_balanced:
            text.append(temp_df.loc[i, "raw_bio"])
            label.append(
                0 if temp_df.loc[i, "chatgpt_eval"].lower() == "communal" else 1
            )
        text.append(temp_df.loc[i, "communal_bio"])
        label.append(0)
        text.append(temp_df.loc[i, "agentic_bio"])
        label.append(1)

    df = pd.DataFrame({"text": text, "label": label})
    # Shuffle dataframe
    df = df.sample(frac=1).reset_index(drop=True)
    return df

def split_data(df, output_path):
    print("Splitting data into {}".format(output_path))
    # Split dataframe into train, validation, and test sets (64/16/20)
    train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
    train_df, val_df = train_test_split(train_df, test_size=0.2, random_state=42)

    # Label distribution per split
    print("Train distribution split: {}".format(Counter(train_df["label"].tolist())))
    print("Val distribution split: {}".format(Counter(val_df["label"].tolist())))
    print("Test distribution split: {}".format(Counter(test_df["label"].tolist())))

    train_path = os.path.join(output_path, "train.csv")
    val_path = os.path.join(output_path, "val.csv")
    test_path = os.path.join(output_path, "test.csv")
    directory = os.path.dirname(train_path)
    if not os.path.exists(directory):
        os.makedirs(directory)

    train_df.to_csv(train_path, index=False)
    val_df.to_csv(val_path, index=False)
    test_df.to_csv(test_path, index=False)

if __name__ == "__main__":
    # Configuration
    parser = ArgumentParser()
    parser.add_argument('-if', '--input_file', default='./agency_bios/BIOS_sampled_preprocessed.csv', required=False)
    parser.add_argument('-of', '--output_folder', default='./agency_dataset/', required=False)
    # A store_true flag, so the balanced option can be enabled from the command line
    parser.add_argument('-ib', '--is_balanced', action='store_true', required=False)
    args = parser.parse_args()

    df = load_from_bios(args.input_file, args.is_balanced)
    split_data(df, args.output_folder)