add: code for running openai
jindongwang committed Oct 9, 2023
1 parent 23676ee commit 01095f0
Showing 3 changed files with 39 additions and 74 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -6,7 +6,7 @@ PromptBench is a powerful tool designed to scrutinize and analyze the interactio



-Check our paper: [PromptBench: Towards Evaluating the Robustness of Large Language Models on Adversarial Prompts](https://arxiv.org/abs/2306.04528).
+Check our paper: [PromptBench: Towards Evaluating the Robustness of Large Language Models on Adversarial Prompts](https://arxiv.org/abs/2306.04528).and the [Demo](https://huggingface.co/spaces/March07/PromptBench) site.


## News
71 changes: 23 additions & 48 deletions inference.py
@@ -1,7 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

-import openai
from config import LABEL_SET, LABEL_TO_ID
from tqdm import tqdm

@@ -216,53 +215,20 @@ def eval(self, preds, gts):
            raise NotImplementedError(
                "Eval this dataset {self.args.dataset} is not implemented!")

-    def predict(self, prompt=None):
-        assert self.args.data is not None, "Please load data first!"
-
-        if self.model in ["chatgpt", "gpt4"]:
-            results = self.predict_by_openai_api(self.model, prompt)
-        else:
-            results = self.predict_by_local_inference(self.model, prompt)
-        return results
-
-    def predict_by_openai_api(self, model, prompt):
-        data_len = len(self.args.data)
-        if data_len > 1000:
-            data_len = 1000
-
-        score = 0
-        check_correctness = 100
-        preds = []
-        gts = []
-
-        for idx in tqdm(range(data_len)):
-
-            raw_data = self.args.data.get_content_by_idx(
-                idx, self.args.dataset)
-            input_text, gt = self.process_input(prompt, raw_data)
-
-            raw_pred = self.call_openai_api(model, input_text)
-            pred = self.process_pred(raw_pred)
-
-            preds.append(pred)
-            gts.append(gt)
-
-            if check_correctness > 0:
-                self.args.logger.info("gt: {}".format(gt))
-                self.args.logger.info("Pred: {}".format(pred))
-                self.args.logger.info("sentence: {}".format(input_text))
-
-                check_correctness -= 1
-
-        score = self.eval(preds, gts)
-        return score
-
-    def predict_by_local_inference(self, model, prompt):
+    def predict(self, prompt):
+        """Predict the final score (e.g., accuracy) for the input prompt using self.model
+
+        Args:
+            prompt (str): prompt
+
+        Returns:
+            float: score (e.g., accuracy)
+        """
        assert self.args.data is not None, "Please load data first!"
        assert prompt is not None, "Please input prompt first!"
        data_len = len(self.args.data)
-        if data_len > 1000:
-            data_len = 1000
+
+        if self.args.max_sample > 0 and data_len > self.args.max_sample:
+            data_len = self.args.max_sample
        score = 0
        check_correctness = 100
        preds = []
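As an aside, the sketch below illustrates how this unified predict entry point is driven end to end. It is a hypothetical usage example: predict(prompt), args.data, and args.max_sample appear in this diff, but the constructor signatures of create_dataset, Inference, and create_logger are assumptions, not the repository's actual code.

```python
# Hypothetical usage sketch -- not code from this commit.
# Only predict(prompt) and the args fields used above come from the diff;
# the create_dataset / Inference / create_logger signatures are assumed.
from dataload import create_dataset
from inference import Inference
from main import create_logger, get_args  # assumed import path

args = get_args()
args.logger = create_logger("./promptbench.log")  # assumed signature
args.data = create_dataset(args.dataset)          # predict() asserts args.data is set
args.max_sample = 200                             # cap evaluation cost, per the new flag

inference_model = Inference(args)                 # assumed: built from the args namespace
prompt = "Classify the sentiment of the following sentence as positive or negative:"
acc = inference_model.predict(prompt)             # returns a score such as accuracy
print("acc: {:.2f}%".format(acc * 100))
```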
@@ -273,8 +239,14 @@ def predict_by_local_inference(self, model, prompt):
            raw_data = self.args.data.get_content_by_idx(
                idx, self.args.dataset)
            input_text, gt = self.process_input(prompt, raw_data)

-            raw_pred = self.pred_by_generation(input_text, model)
+            if self.model in ['chatgpt', 'gpt4']:
+                if (idx+1) % 40 == 0:  # random sleep for every 40 requests
+                    import time
+                    import random
+                    time.sleep(random.random() * 10 + 5)
+                raw_pred = self.call_openai_api(self.model, input_text)
+            else:
+                raw_pred = self.pred_by_generation(input_text, self.model)
            pred = self.process_pred(raw_pred)

            preds.append(pred)
@@ -291,9 +263,12 @@ def predict_by_local_inference(self, model, prompt):
        return score

    def call_openai_api(self, model, prompt):
+        import random
+        import time
+        time.sleep(random.random() + 3)
        import openai
        from config import OPENAI_API
-        openai.api_key = OPENAI_API
+        openai.api_key = OPENAI_API()
        if model in ['chatgpt']:
            response = openai.Completion.create(
                model="gpt-3.5-turbo-instruct",
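The diff cuts off inside call_openai_api, just after the Completion call is opened. For orientation, here is a minimal sketch of the full shape of such a call against the pre-1.0 openai SDK; the helper name, the sampling parameters, and the gpt-4 branch are assumptions rather than the repository's actual code (only the throttling sleep, OPENAI_API(), and the gpt-3.5-turbo-instruct Completion call appear in this diff).

```python
# Hypothetical sketch of a complete OpenAI call in this style (pre-1.0 `openai` SDK).
# The api_key argument stands in for whatever OPENAI_API() returns in config.py;
# temperature and max_tokens values are guesses.
import random
import time

import openai


def call_openai_api_sketch(model, prompt, api_key):
    time.sleep(random.random() + 3)  # simple throttling, as in the committed code
    openai.api_key = api_key
    if model == "chatgpt":
        response = openai.Completion.create(
            model="gpt-3.5-turbo-instruct",
            prompt=prompt,
            temperature=0,
            max_tokens=20,
        )
        return response["choices"][0]["text"]
    elif model == "gpt4":
        response = openai.ChatCompletion.create(
            model="gpt-4",
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
        )
        return response["choices"][0]["message"]["content"]
    raise ValueError("Unsupported OpenAI model: {}".format(model))
```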
40 changes: 15 additions & 25 deletions main.py
@@ -8,9 +8,9 @@
from config import *
from dataload import create_dataset
from inference import Inference
-from prompt_attack.attack import create_attack
-from prompt_attack.goal_function import create_goal_function
-from config import MODEL_SET
+# from prompt_attack.attack import create_attack
+# from prompt_attack.goal_function import create_goal_function
+from config import MODEL_SET, DATA_SET, ATTACK_SET


def create_logger(log_path):
@@ -35,28 +35,11 @@ def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str,
                        default='google/flan-t5-large', choices=MODEL_SET)
-    parser.add_argument('--dataset', type=str, default='bool_logic', choices=["sst2", "cola", "qqp",
-                                                                              "mnli", "mnli_matched", "mnli_mismatched",
-                                                                              "qnli", "wnli", "rte", "mrpc",
-                                                                              "mmlu", "squad_v2", "un_multi", "iwslt", "math",
-                                                                              "bool_logic", "valid_parentheses",
-                                                                              ])
+    parser.add_argument('--dataset', type=str, default='bool_logic', choices=DATA_SET)

    parser.add_argument('--query_budget', type=float, default=float("inf"))
-    parser.add_argument('--attack', type=str, default='deepwordbug', choices=[
-        'textfooler',
-        'textbugger',
-        'bertattack',
-        'deepwordbug',
-        'checklist',
-        'stresstest',
-        'semantic',
-        'no',
-        'noattack',
-        'clean',
-    ])
+    parser.add_argument('--attack', type=str, default='deepwordbug', choices=ATTACK_SET)
    parser.add_argument("--verbose", type=bool, default=True)

    parser.add_argument('--output_dir', type=str, default='./')

    parser.add_argument('--model_dir', type=str, default="/home/v-kaijiezhu/")
@@ -67,6 +50,11 @@

    parser.add_argument('--prompt_selection', action='store_true')

+    # maximum samples allowed to predict due to high cost of some models
+    parser.add_argument('--max_sample', type=int, default=1000)
+
+    parser.add_argument('--clean_attack', type=str, default='clean')
+
    args = parser.parse_args()
    return args
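The switch from inline choices= lists to choices=DATA_SET and choices=ATTACK_SET implies that config.py now owns these lists. Below is a plausible sketch of those constants, reconstructed from the inline lists this commit removes; the actual contents and ordering in config.py may differ.

```python
# Sketch of the constants main.py now imports from config.py; values are copied
# from the `choices=` lists removed above, not read from the real config.py.
DATA_SET = [
    "sst2", "cola", "qqp",
    "mnli", "mnli_matched", "mnli_mismatched",
    "qnli", "wnli", "rte", "mrpc",
    "mmlu", "squad_v2", "un_multi", "iwslt", "math",
    "bool_logic", "valid_parentheses",
]

ATTACK_SET = [
    "textfooler", "textbugger", "bertattack", "deepwordbug",
    "checklist", "stresstest", "semantic",
    "no", "noattack", "clean",
]
```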

@@ -105,11 +93,13 @@ def attack(args, inference_model, RESULTS_DIR):
                language, acc*100, prompt))
    elif args.attack in ['no', 'noattack', 'clean']:
        from config import PROMPT_SET_Promptbench_advglue as prompt_raw
-        prompt = prompt_raw['clean'][args.dataset][0]
+        prompt = prompt_raw[args.clean_attack][args.dataset][0]
        acc = inference_model.predict(prompt)
-        args.logger.info(f"Prompt: {prompt}, acc: {acc}%\n")
+        info = "Prompt: {}, acc: {:.2f}%\n".format(prompt, acc*100)
+        args.logger.info(info)
+        print(args, info)
        with open(RESULTS_DIR+args.save_file_name+".txt", "a+") as f:
-            f.write("Prompt: {}, acc: {:.2f}%\n".format(prompt, acc*100))
+            f.write(info)
    else:
        if args.shot == 0:
            from prompts.zero_shot.task_oriented import TASK_ORIENTED_PROMPT_SET
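The lookup prompt_raw[args.clean_attack][args.dataset][0] above implies that PROMPT_SET_Promptbench_advglue is a nested mapping from attack name to dataset name to a list of candidate prompts. A minimal illustration of that shape follows; the prompt strings are invented for the example, not taken from the repository.

```python
# Illustrative structure only; the keys follow the indexing pattern used in attack(),
# and the prompt texts are made up.
PROMPT_SET_Promptbench_advglue = {
    "clean": {
        "sst2": [
            "Classify the sentiment of the following sentence as positive or negative:",
        ],
        "bool_logic": [
            "Evaluate the following boolean expression and answer True or False:",
        ],
    },
}

prompt = PROMPT_SET_Promptbench_advglue["clean"]["sst2"][0]
```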
