diff --git a/README.md b/README.md
index 2103bc6..4dbfa18 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@ PromptBench is a powerful tool designed to scrutinize and analyze the interactio
 
-Check our paper: [PromptBench: Towards Evaluating the Robustness of Large Language Models on Adversarial Prompts](https://arxiv.org/abs/2306.04528).
+Check our paper: [PromptBench: Towards Evaluating the Robustness of Large Language Models on Adversarial Prompts](https://arxiv.org/abs/2306.04528) and the [Demo](https://huggingface.co/spaces/March07/PromptBench) site.
 
 ## News
diff --git a/inference.py b/inference.py
index ea8870f..97154b7 100644
--- a/inference.py
+++ b/inference.py
@@ -1,7 +1,6 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 
-import openai
 from config import LABEL_SET, LABEL_TO_ID
 from tqdm import tqdm
 
@@ -216,53 +215,20 @@ def eval(self, preds, gts):
             raise NotImplementedError(
                 "Eval this dataset {self.args.dataset} is not implemented!")
 
-    def predict(self, prompt=None):
-        assert self.args.data is not None, "Please load data first!"
-
-        if self.model in ["chatgpt", "gpt4"]:
-            results = self.predict_by_openai_api(self.model, prompt)
-        else:
-            results = self.predict_by_local_inference(self.model, prompt)
-        return results
-
-    def predict_by_openai_api(self, model, prompt):
-        data_len = len(self.args.data)
-        if data_len > 1000:
-            data_len = 1000
-
-        score = 0
-        check_correctness = 100
-        preds = []
-        gts = []
-
-        for idx in tqdm(range(data_len)):
-
-            raw_data = self.args.data.get_content_by_idx(
-                idx, self.args.dataset)
-            input_text, gt = self.process_input(prompt, raw_data)
-
-            raw_pred = self.call_openai_api(model, input_text)
-            pred = self.process_pred(raw_pred)
+    def predict(self, prompt):
+        """Predict the final score (e.g., accuracy) for the input prompt using self.model
 
-            preds.append(pred)
-            gts.append(gt)
-
-            if check_correctness > 0:
-                self.args.logger.info("gt: {}".format(gt))
-                self.args.logger.info("Pred: {}".format(pred))
-                self.args.logger.info("sentence: {}".format(input_text))
-
-                check_correctness -= 1
+        Args:
+            prompt (str): prompt
 
-        score = self.eval(preds, gts)
-        return score
-
-
-    def predict_by_local_inference(self, model, prompt):
+        Returns:
+            float: score (e.g., accuracy)
+        """
+        assert self.args.data is not None, "Please load data first!"
+        assert prompt is not None, "Please input prompt first!"
         data_len = len(self.args.data)
-        if data_len > 1000:
-            data_len = 1000
-
+        if self.args.max_sample > 0 and data_len > self.args.max_sample:
+            data_len = self.args.max_sample
         score = 0
         check_correctness = 100
         preds = []
@@ -273,8 +239,14 @@ def predict_by_local_inference(self, model, prompt):
             raw_data = self.args.data.get_content_by_idx(
                 idx, self.args.dataset)
             input_text, gt = self.process_input(prompt, raw_data)
-
-            raw_pred = self.pred_by_generation(input_text, model)
+            if self.model in ['chatgpt', 'gpt4']:
+                if (idx+1) % 40 == 0:  # random sleep for every 40 requests
+                    import time
+                    import random
+                    time.sleep(random.random() * 10 + 5)
+                raw_pred = self.call_openai_api(self.model, input_text)
+            else:
+                raw_pred = self.pred_by_generation(input_text, self.model)
             pred = self.process_pred(raw_pred)
 
             preds.append(pred)
@@ -291,9 +263,12 @@ def predict_by_local_inference(self, model, prompt):
         return score
 
     def call_openai_api(self, model, prompt):
+        import random
+        import time
+        time.sleep(random.random() + 3)
         import openai
         from config import OPENAI_API
-        openai.api_key = OPENAI_API
+        openai.api_key = OPENAI_API()
         if model in ['chatgpt']:
             response = openai.Completion.create(
                 model="gpt-3.5-turbo-instruct",
diff --git a/main.py b/main.py
index b9055f3..27d655a 100644
--- a/main.py
+++ b/main.py
@@ -8,9 +8,9 @@
 from config import *
 from dataload import create_dataset
 from inference import Inference
-from prompt_attack.attack import create_attack
-from prompt_attack.goal_function import create_goal_function
-from config import MODEL_SET
+# from prompt_attack.attack import create_attack
+# from prompt_attack.goal_function import create_goal_function
+from config import MODEL_SET, DATA_SET, ATTACK_SET
 
 
 def create_logger(log_path):
@@ -35,28 +35,11 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--model', type=str, default='google/flan-t5-large', choices=MODEL_SET)
-    parser.add_argument('--dataset', type=str, default='bool_logic', choices=["sst2", "cola", "qqp",
-                                                                              "mnli", "mnli_matched", "mnli_mismatched",
-                                                                              "qnli", "wnli", "rte", "mrpc",
-                                                                              "mmlu", "squad_v2", "un_multi", "iwslt", "math",
-                                                                              "bool_logic", "valid_parentheses",
-                                                                              ])
+    parser.add_argument('--dataset', type=str, default='bool_logic', choices=DATA_SET)
     parser.add_argument('--query_budget', type=float, default=float("inf"))
-    parser.add_argument('--attack', type=str, default='deepwordbug', choices=[
-        'textfooler',
-        'textbugger',
-        'bertattack',
-        'deepwordbug',
-        'checklist',
-        'stresstest',
-        'semantic',
-        'no',
-        'noattack',
-        'clean',
-    ])
+    parser.add_argument('--attack', type=str, default='deepwordbug', choices=ATTACK_SET)
     parser.add_argument("--verbose", type=bool, default=True)
-
     parser.add_argument('--output_dir', type=str, default='./')
     parser.add_argument('--model_dir', type=str, default="/home/v-kaijiezhu/")
@@ -67,6 +50,11 @@ def get_args():
     parser.add_argument('--prompt_selection', action='store_true')
 
+    # maximum samples allowed to predict due to high cost of some models
+    parser.add_argument('--max_sample', type=int, default=1000)
+
+    parser.add_argument('--clean_attack', type=str, default='clean')
+
     args = parser.parse_args()
 
     return args
@@ -105,11 +93,13 @@ def attack(args, inference_model, RESULTS_DIR):
                 language, acc*100, prompt))
     elif args.attack in ['no', 'noattack', 'clean']:
         from config import PROMPT_SET_Promptbench_advglue as prompt_raw
-        prompt = prompt_raw['clean'][args.dataset][0]
+        prompt = prompt_raw[args.clean_attack][args.dataset][0]
         acc = inference_model.predict(prompt)
-        args.logger.info(f"Prompt: {prompt}, acc: {acc}%\n")
+        info = "Prompt: {}, acc: {:.2f}%\n".format(prompt, acc*100)
+        args.logger.info(info)
+        print(args, info)
         with open(RESULTS_DIR+args.save_file_name+".txt", "a+") as f:
-            f.write("Prompt: {}, acc: {:.2f}%\n".format(prompt, acc*100))
+            f.write(info)
     else:
         if args.shot == 0:
             from prompts.zero_shot.task_oriented import TASK_ORIENTED_PROMPT_SET
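
Below is a minimal, illustrative sketch (not part of the patch) of the sample-capping and request-throttling pattern that the reworked Inference.predict applies to API-backed models; call_api and generate_locally are hypothetical stand-ins for Inference.call_openai_api and Inference.pred_by_generation.

# Illustrative sketch only, assuming hypothetical helpers; not part of the patch above.
import random
import time

def call_api(model, text):
    # Placeholder for Inference.call_openai_api (would issue an OpenAI request).
    return "positive"

def generate_locally(model, text):
    # Placeholder for Inference.pred_by_generation (local model inference).
    return "negative"

def run_inference(samples, model, max_sample=1000):
    # Cap the number of evaluated samples, mirroring the new --max_sample argument.
    if max_sample > 0 and len(samples) > max_sample:
        samples = samples[:max_sample]
    preds = []
    for idx, text in enumerate(samples):
        if model in ("chatgpt", "gpt4"):
            # Random 5-15 s pause every 40 requests, matching the throttling
            # added to Inference.predict to stay under API rate limits.
            if (idx + 1) % 40 == 0:
                time.sleep(random.random() * 10 + 5)
            preds.append(call_api(model, text))
        else:
            preds.append(generate_locally(model, text))
    return preds

if __name__ == "__main__":
    print(run_inference(["example input"] * 3, "google/flan-t5-large"))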