Skip to content

Commit

Permalink
Merge pull request #10 from morganmcg1/add_wandb_to_classification
Browse files Browse the repository at this point in the history
Add wandb to the Classification Notebook
  • Loading branch information
mshumer authored Jul 11, 2023
2 parents 9d01515 + a9be4eb commit 1b89611
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 66 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ Prompt engineering is kind of like alchemy. There's no clear way to predict what
- **Classification Version**: The `gpt-prompt-engineer -- Classification Version` notebook is designed to handle classification tasks. It evaluates the correctness of a test case by matching it to the expected output ('true' or 'false') and provides a table with scores for each prompt.
<img width="1607" alt="Screen Shot 2023-07-10 at 5 22 24 PM" src="https://github.com/mshumer/gpt-prompt-engineer/assets/41550495/d5c9f2a8-97fa-445d-9c38-dec744f77854">

- **Weights & Biases Logging**: Optional logging to [Weights & Biases](https://wandb.ai/site) of your configs such as temperature and max tokens, the system and user prompts for each part, the test cases used and the final ranked ELO rating for each candidate prompt. Set `use_wandb` to `True` to use. Only available in the main `gpt-prompt-engineer` notebook for now.

- **[Weights & Biases](https://wandb.ai/site/prompts) Logging**: Optional logging to [Weights & Biases](https://wandb.ai/site) of your configs such as temperature and max tokens, the system and user prompts for each part, the test cases used and the final ranked ELO rating for each candidate prompt. Set `use_wandb` to `True` to use.

## Setup
1. [Open the notebook in Google Colab](https://colab.research.google.com/github/mshumer/gpt-prompt-engineer/blob/main/gpt_prompt_engineer.ipynb) or in a local Jupyter notebook. For classification, use [this one.](https://colab.research.google.com/drive/16NLMjqyuUWxcokE_NF6RwHD8grwEeoaJ?usp=sharing)
Expand Down
250 changes: 185 additions & 65 deletions gpt_prompt_engineer_Classification_Version.ipynb
Original file line number Diff line number Diff line change
@@ -1,33 +1,22 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyMvbQztC95mJY9x+Gc/uEm+",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/mshumer/gpt-prompt-engineer/blob/main/gpt_prompt_engineer_Classification_Version.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "L0Ey7JZ5iLo1"
},
"source": [
"# gpt-prompt-engineer -- Classification Version\n",
"By Matt Shumer (https://twitter.com/mattshumer_)\n",
Expand All @@ -40,85 +29,171 @@
"\n",
"To generate a prompt:\n",
"1. In the first cell, add in your OpenAI key.\n",
"2. If you don't have GPT-4 access, change `model='gpt-4'` in the second cell to `model='gpt-3.5-turbo'`. If you do have access, skip this step.\n",
"2. If you don't have GPT-4 access, change `CANDIDATE_MODEL='gpt-4'` in the second cell to `CANDIDATE_MODEL='gpt-3.5-turbo'`. If you do have access, skip this step.\n",
"2. In the last cell, fill in the description of your task, as many test cases as you want (test cases are example prompts and their expected output), and the number of prompts to generate.\n",
"3. Run all the cells! The AI will generate a number of candidate prompts, and test them all to find the best one!"
],
"metadata": {
"id": "L0Ey7JZ5iLo1"
}
"3. Run all the cells! The AI will generate a number of candidate prompts, and test them all to find the best one!\n",
"\n",
"🪄🐝 To use [Weights & Biases logging](https://wandb.ai/site/prompts) to your LLM configs and the generated prompt outputs, just set `use_wandb = True`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install openai prettytable tqdm tenacity wandb -qq"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "UW3ztLRsolnk"
},
"outputs": [],
"source": [
"!pip install openai\n",
"!pip install prettytable\n",
"\n",
"from prettytable import PrettyTable\n",
"import time\n",
"import wandb\n",
"import openai\n",
"from tenacity import retry, stop_after_attempt, wait_exponential\n",
"\n",
"openai.api_key = \"ADD YOUR KEY HERE\" # enter your OpenAI API key here"
"openai.api_key = \"ADD YOUR KEY HERE\" # enter your OpenAI API key here\n",
"\n",
"use_wandb = True # set to True if you want to use wandb to log your config and results"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"def generate_candidate_prompts(description, test_cases, number_of_prompts):\n",
" outputs = openai.ChatCompletion.create(\n",
" model='gpt-4',\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"\"\"Your job is to generate system prompts for GPT-4, given a description of the use-case and some test cases.\n",
"candidate_gen_system_prompt = \"\"\"Your job is to generate system prompts for GPT-4, given a description of the use-case and some test cases.\n",
"\n",
"The prompts you will be generating will be for classifiers, with 'true' and 'false' being the only possible outputs.\n",
"\n",
"In your generated prompt, you should describe how the AI should behave in plain English. Include what it will see, and what it's allowed to output. Be creative in with prompts to get the best possible results. The AI knows it's an AI -- you don't need to tell it this.\n",
"\n",
"You will be graded based on the performance of your prompt... but don't cheat! You cannot include specifics about the test cases in your prompt. Any prompts with examples will be disqualified.\n",
"\n",
"Most importantly, output NOTHING but the prompt. Do not include anything else in your message.\"\"\"},\n",
"Most importantly, output NOTHING but the prompt. Do not include anything else in your message.\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"CANDIDATE_MODEL = 'gpt-4'\n",
"CANDIDATE_MODEL_TEMPERATURE = 0.9\n",
"\n",
"EVAL_MODEL = 'gpt-3.5-turbo'\n",
"EVAL_MODEL_TEMPERATURE = 0\n",
"EVAL_MODEL_MAX_TOKENS = 1\n",
"\n",
"NUMBER_OF_PROMPTS = 10 # this determines how many candidate prompts to generate... the higher, the more expensive\n",
"\n",
"N_RETRIES = 3 # number of times to retry a call to the ranking model if it fails\n",
"\n",
"WANDB_PROJECT_NAME = \"gpt-prompt-eng\" # used if use_wandb is True, Weights &| Biases project name\n",
"WANDB_RUN_NAME = None # used if use_wandb is True, optionally set the Weights & Biases run name to identify this run"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"def start_wandb_run():\n",
" # start a new wandb run and log the config\n",
" wandb.init(\n",
" project=WANDB_PROJECT_NAME, \n",
" name=WANDB_RUN_NAME,\n",
" config={\n",
" \"candidate_gen_system_prompt\": candidate_gen_system_prompt, \n",
" \"candiate_model\": CANDIDATE_MODEL,\n",
" \"candidate_model_temperature\": CANDIDATE_MODEL_TEMPERATURE,\n",
" \"generation_model\": EVAL_MODEL,\n",
" \"generation_model_temperature\": EVAL_MODEL_TEMPERATURE,\n",
" \"generation_model_max_tokens\": EVAL_MODEL_MAX_TOKENS,\n",
" \"n_retries\": N_RETRIES,\n",
" \"number_of_prompts\": NUMBER_OF_PROMPTS\n",
" })\n",
" \n",
" return "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional logging to Weights & Biases to reocrd the configs, prompts and results\n",
"if use_wandb:\n",
" start_wandb_run()"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"id": "KTRFiBhSouz8"
},
"outputs": [],
"source": [
"# Get Score - retry up to N_RETRIES times, waiting exponentially between retries.\n",
"@retry(stop=stop_after_attempt(N_RETRIES), wait=wait_exponential(multiplier=1, min=4, max=70))\n",
"def generate_candidate_prompts(description, test_cases, number_of_prompts):\n",
" outputs = openai.ChatCompletion.create(\n",
" model=CANDIDATE_MODEL,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": candidate_gen_system_prompt},\n",
" {\"role\": \"user\", \"content\": f\"Here are some test cases:`{test_cases}`\\n\\nHere is the description of the use-case: `{description.strip()}`\\n\\nRespond with your prompt, and nothing else. Be creative.\"}\n",
" ],\n",
" temperature=.9,\n",
" temperature=CANDIDATE_MODEL_TEMPERATURE,\n",
" n=number_of_prompts)\n",
"\n",
" prompts = []\n",
"\n",
" for i in outputs.choices:\n",
" prompts.append(i.message.content)\n",
" return prompts"
],
"metadata": {
"id": "KTRFiBhSouz8"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"id": "w4ltgxntszwK"
},
"outputs": [],
"source": [
"def test_candidate_prompts(test_cases, prompts):\n",
" prompt_results = {prompt: {'correct': 0, 'total': 0} for prompt in prompts}\n",
"\n",
" # Initialize the table\n",
" table = PrettyTable()\n",
" table.field_names = [\"Prompt\", \"Expected\"] + [f\"Prompt {i+1}-{j+1}\" for j, prompt in enumerate(prompts) for i in range(prompts.count(prompt))]\n",
"\n",
" table_field_names = [\"Prompt\", \"Expected\"] + [f\"Prompt {i+1}-{j+1}\" for j, prompt in enumerate(prompts) for i in range(prompts.count(prompt))]\n",
" table.field_names = table_field_names\n",
"\n",
" # Wrap the text in the \"Prompt\" column\n",
" table.max_width[\"Prompt\"] = 100\n",
"\n",
" if use_wandb:\n",
" wandb_table = wandb.Table(columns=table_field_names)\n",
" if wandb.run is None:\n",
" start_wandb_run()\n",
"\n",
" for test_case in test_cases:\n",
" row = [test_case['prompt'], test_case['answer']]\n",
" for prompt in prompts:\n",
" x = openai.ChatCompletion.create(\n",
" model='gpt-3.5-turbo',\n",
" model=EVAL_MODEL,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": prompt},\n",
" {\"role\": \"user\", \"content\": f\"{test_case['prompt']}\"}\n",
Expand All @@ -127,8 +202,8 @@
" '1904': 100, # 'true' token\n",
" '3934': 100, # 'false' token\n",
" },\n",
" max_tokens=1,\n",
" temperature=0,\n",
" max_tokens=EVAL_MODEL_MAX_TOKENS,\n",
" temperature=EVAL_MODEL_TEMPERATURE,\n",
" ).choices[0].message.content\n",
"\n",
"\n",
Expand All @@ -141,31 +216,46 @@
" prompt_results[prompt]['total'] += 1\n",
"\n",
" table.add_row(row)\n",
" if use_wandb:\n",
" wandb_table.add_data(*row)\n",
"\n",
" print(table)\n",
"\n",
" # Calculate and print the percentage of correct answers and average time for each model\n",
" best_prompt = None\n",
" best_percentage = 0\n",
" if use_wandb:\n",
" prompts_results_table = wandb.Table(columns=[\"Prompt Number\", \"Prompt\", \"Percentage\", \"Correct\", \"Total\"])\n",
" \n",
" for i, prompt in enumerate(prompts):\n",
" correct = prompt_results[prompt]['correct']\n",
" total = prompt_results[prompt]['total']\n",
" percentage = (correct / total) * 100\n",
" print(f\"Prompt {i+1} got {percentage:.2f}% correct.\")\n",
" if use_wandb:\n",
" prompts_results_table.add_data(i, prompt, percentage, correct, total)\n",
" if percentage > best_percentage:\n",
" best_percentage = percentage\n",
" best_prompt = prompt\n",
"\n",
" if use_wandb: # log the results to a Weights & Biases table and finsih the run\n",
" wandb.log({\"prompt_results\": prompts_results_table})\n",
" best_prompt_table = wandb.Table(columns=[\"Best Prompt\", \"Best Percentage\"])\n",
" best_prompt_table.add_data(best_prompt, best_percentage)\n",
" wandb.log({\"best_prompt\": best_prompt_table})\n",
" wandb.log({\"prompt_ratings\": wandb_table})\n",
" wandb.finish()\n",
"\n",
" print(f\"The best prompt was '{best_prompt}' with a correctness of {best_percentage:.2f}%.\")"
],
"metadata": {
"id": "w4ltgxntszwK"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"id": "SBJEi1hkrT9T"
},
"outputs": [],
"source": [
"test_cases = [\n",
" {\n",
Expand Down Expand Up @@ -275,22 +365,52 @@
" {\n",
" 'prompt': 'Plan a surprise birthday party for my best friend.',\n",
" 'answer': 'false'\n",
" }]\n",
"\n",
"\n",
" }]"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"description = \"Decide if a task is research-heavy.\" # describe the classification task clearly\n",
"number_of_prompts = 10 # choose how many prompts you want to generate and test\n",
"\n",
"\n",
"# If Weights & Biases is enabled, log the description and test cases too\n",
"if use_wandb:\n",
" if wandb.run is None:\n",
" start_wandb_run()\n",
" wandb.config.update({\"description\": description, \n",
" \"test_cases\": test_cases})\n",
"\n",
"candidate_prompts = generate_candidate_prompts(description, test_cases, number_of_prompts)\n",
"candidate_prompts = generate_candidate_prompts(description, test_cases, NUMBER_OF_PROMPTS)\n",
"test_candidate_prompts(test_cases, candidate_prompts)"
],
"metadata": {
"id": "SBJEi1hkrT9T"
]
}
],
"metadata": {
"colab": {
"authorship_tag": "ABX9TyMvbQztC95mJY9x+Gc/uEm+",
"include_colab_link": true,
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"execution_count": null,
"outputs": []
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
}
]
}
},
"nbformat": 4,
"nbformat_minor": 0
}

0 comments on commit 1b89611

Please sign in to comment.