forked from centerforaisafety/HarmBench
-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_test_cases.py
118 lines (100 loc) · 5.34 KB
/
generate_test_cases.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import transformers
from baselines import init_method
import yaml
import argparse
import csv
import torch
import os
from os.path import join
from eval_utils import get_experiment_config
import pandas as pd
# Set this to disable warning messages in the generation mode.
transformers.utils.logging.set_verbosity_error()
import argparse
def parse_args():
    """Build the CLI for red-teaming test-case generation and parse argv.

    Returns:
        argparse.Namespace with the method/experiment selection, config file
        paths, behavior subset indices, output options, and verbosity flags.
    """
    p = argparse.ArgumentParser(description="Running red teaming with baseline methods.")
    p.add_argument("--method_name", type=str, default='GCG',
                   help="The name of the red teaming method")
    p.add_argument("--experiment_name", type=str,
                   help="The name of the experiment (options are in configs/methods_config/{{method_name}}_config.yaml)")
    p.add_argument("--method_config_file", type=str, default=None,
                   help="The path to the config file with method hyperparameters. This will default to configs/method_configs/{{method_name}}_config.yaml")
    p.add_argument("--models_config_file", type=str, default='./configs/model_configs/models.yaml',
                   help="The path to the config file with model hyperparameters")
    p.add_argument("--behaviors_path", type=str, default='./data/400_behaviors.csv',
                   help="The path to the behaviors file")
    p.add_argument("--save_dir", type=str, default='./test_cases',
                   help="The directory used for saving test cases")
    p.add_argument("--behavior_start_idx", type=int, default=None,
                   help="Start index for behaviors_path data")
    p.add_argument("--behavior_end_idx", type=int, default=None,
                   help="End index for behaviors_path data")
    p.add_argument("--run_id", type=str, default=None,
                   help="An optional run id for output files, will be appended to output filenames")
    p.add_argument("--overwrite", action="store_true",
                   help="Whether to overwrite existing saved test cases")
    p.add_argument("--verbose", action="store_true",
                   help="Whether to print out intermediate results")
    return p.parse_args()
def _load_yaml(path):
    """Parse a YAML config file and return its contents.

    NOTE(review): yaml.full_load can construct arbitrary Python objects.
    That is acceptable for trusted local configs (the original behavior is
    preserved), but switch to yaml.safe_load if these files could ever come
    from an untrusted source.
    """
    with open(path) as file:
        return yaml.full_load(file)


def _load_behaviors(behaviors_path, start_idx, end_idx):
    """Read the behaviors CSV into a list of dict rows.

    If both start_idx and end_idx are given, returns the inclusive
    [start_idx, end_idx] slice of the rows (hence the +1 on the end index).
    """
    with open(behaviors_path, 'r', encoding='utf-8') as f:
        behaviors = list(csv.DictReader(f))
    if start_idx is not None and end_idx is not None:
        print(f"Selecting subset from [{start_idx}, {end_idx}]")
        behaviors = behaviors[start_idx:end_idx + 1]
    return behaviors


def _filter_finished_behaviors(method, behaviors, save_dir, run_id):
    """Return only the behaviors whose test-case output file does not yet
    exist on disk; print a skip notice for the ones that do."""
    remaining = []
    for behavior in behaviors:
        behavior_id = behavior['BehaviorID']
        test_case_file = method.get_output_file_path(save_dir, behavior_id, file_type='test_cases', run_id=run_id)
        if os.path.exists(test_case_file):
            print(f"===>Skipping {behavior_id}")
        else:
            remaining.append(behavior)
    return remaining


def main():
    """Generate red-teaming test cases for the selected method/experiment.

    Loads model and method configs, reads the behaviors CSV (optionally a
    subset), initializes the attack method, skips behaviors with existing
    outputs (unless --overwrite), runs the attack, and saves results.

    Raises:
        ValueError: if the method config file lacks
            'default_method_hyperparameters'.
    """
    # ========== load arguments and config ========== #
    args = parse_args()
    method_name = args.method_name
    experiment_name = args.experiment_name
    save_dir = os.path.abspath(args.save_dir)

    # Load model config file (CLI flag wins; fall back to the default path).
    model_configs = _load_yaml(args.models_config_file or "configs/model_configs/models.yaml")

    # Load method config file (path is derived from the method name by default).
    method_config_path = args.method_config_file or f"configs/method_configs/{method_name}_config.yaml"
    method_configs = _load_yaml(method_config_path)

    # Load default method parameters; an absent key is a config error,
    # but an explicitly empty (None) value is treated as "no defaults".
    if 'default_method_hyperparameters' not in method_configs:
        raise ValueError(f"Default method hyperparameters not found in {method_config_path}")
    method_config = method_configs.get('default_method_hyperparameters') or {}

    # Update method parameters with experiment-specific parameters.
    experiment_config = get_experiment_config(experiment_name, model_configs, method_configs)
    method_config.update(experiment_config)

    # ========== generate test cases ========== #
    behaviors = _load_behaviors(args.behaviors_path, args.behavior_start_idx, args.behavior_end_idx)

    world_size = torch.cuda.device_count()
    print(f"============={method_name} Method Config=============")
    method_config['method_name'] = method_name
    method_config['experiment_name'] = experiment_name
    print('\n'.join(f"{k}: {v}" for k, v in method_config.items()))
    print("=============GPU Allocation Configs=============")
    print(f"Total GPUs: {world_size}")
    print("=============")

    # ==== Init Method ====
    method = init_method(method_name, method_config)

    # ==== Skip behaviors whose outputs already exist (unless --overwrite) ====
    if not args.overwrite:
        behaviors = _filter_finished_behaviors(method, behaviors, save_dir, args.run_id)

    # ====== Run attack ======
    print("#Behaviors to run:" , len(behaviors))
    all_test_cases, all_logs = method.generate_test_cases(behaviors=behaviors, verbose=args.verbose)

    # ====== Save test cases and logs ======
    print("Saving outputs to: ", save_dir)
    method.save_test_cases(save_dir, all_test_cases, all_logs, method_config=method_config, run_id=args.run_id)
    print("Finished")
# Standard script entry guard: run main() only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()