From f85784dff3b0d66b62fcd14f552e278c42b68709 Mon Sep 17 00:00:00 2001 From: Zack Witten Date: Mon, 21 Aug 2023 20:34:46 -0700 Subject: [PATCH] with results table --- long_context/mc_qa.ipynb | 731 +++++++++++++++++++++++++++++++++------ 1 file changed, 629 insertions(+), 102 deletions(-) diff --git a/long_context/mc_qa.ipynb b/long_context/mc_qa.ipynb index 622e100..32381d4 100644 --- a/long_context/mc_qa.ipynb +++ b/long_context/mc_qa.ipynb @@ -38,7 +38,19 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import anthropic, os, re, requests, trio, pandas as pd\n", + "from bs4 import BeautifulSoup\n", + "API_KEY = os.environ['ANTHROPIC_API_KEY']\n", + "CLIENT = anthropic.Anthropic(api_key=API_KEY)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -110,9 +122,6 @@ } ], "source": [ - "import anthropic, os, re, requests, tiktoken, trio, pandas as pd, numpy as np\n", - "from bs4 import BeautifulSoup\n", - "\n", "url = 'https://www.govinfo.gov/content/pkg/FR-2023-07-13/xml/FR-2023-07-13.xml'\n", "\n", "response = requests.get(url)\n", @@ -123,8 +132,7 @@ "chunks[0] = chunks[0][chunks[0].index('DEPARTMENT OF TRANSPORTATION'):] # First chunk has some extra material at the beginning.\n", "\n", "# We'll throw out the chunks that are extra-long or extra-short.\n", - "import tiktoken\n", - "tokenizer = tiktoken.get_encoding(\"cl100k_base\")\n", + "tokenizer = CLIENT.get_tokenizer()\n", "chunks = [c for c in chunks if len(tokenizer.encode(c)) <= 5000 and len(tokenizer.encode(c)) > 200]\n", "print(len(chunks))\n", "print(chunks[2])" @@ -146,7 +154,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -289,9 +297,11 @@ "{test_passage}\n", "\n", "\n", - "Please write five factual questions about it. For each question, give three wrong answers and the right answer.\n", - "Always put the correct answer first. Write 4 non-numerical questions and one numerical one. Make sure the wrong answers are highly detailed.\n", - "Put the question inside tags, and the answers inside tags, where N is the index of the question, as in the examples.\n", + "Please write five factual questions about this document that can be answered with reference to it and without any outside knowledge. For each question, give three wrong answers and the right answer. Always put the correct answer first. Write 4 non-numerical questions and one numerical one. Make sure the wrong answers are highly detailed. Put the question inside tags, and the answers inside tags, where N is the index of the question, as in the examples. \n", + "\n", + "Guidelines:\n", + "Make sure that each question clearly and independently identifies the section/minutes/government meeting from which it derives; avoid terms like \"this document\", \"this passage\", \"this notice\" in favor of more specific descriptions. The goal is to future-proof the questions and answers in the event that they became divorced from their subject in the filing system.\n", + "Make the questions specific to their source text. Eschew generic questions about date of publication or name of agency. 
Instead, prefer questions that could not apply to notes produced by any other department/agency.\n", "\n", "Assistant:\n", "\"\"\"" @@ -301,14 +311,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "A key detail to pay attention to in the prompt above: the instruction to make the wrong answers \"highly detailed\". Without this instruction, the wrong answers tended to be relatively short and the right answer stood out on length alone.\n", + "A key detail to pay attention to in the prompt above: the instruction to make the wrong answers \"highly detailed\". Without this instruction, the wrong answers tended to be relatively short and the right answer stood out on length alone. Put a pin in the instruction to \"Make sure that each question clearly and independently identifies the section/minutes/government meeting from which it derives\"; we'll come back to it later.\n", "\n", "Now, we'll make a dataframe with a column where we fill in the prompt template for each chunk, excluding the two chunks we used in the two-shot." ] }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -334,25 +344,47 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In this notebook, we'll use Claude Instant, which has 100K context window just like Claude 2. You can also run it with Claude 2 to similar results. Tip: get faster output by using the trio library to parallelize these calls." + "In this notebook, we'll use Claude Instant, which has a 100K context window just like Claude 2. You can also run it with Claude 2 to similar results." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ - "API_KEY = os.environ['ANTHROPIC_API_KEY']\n", - "CLIENT = anthropic.Anthropic(api_key=API_KEY)\n", - "\n", - "pd.set_option('mode.chained_assignment', None)\n", - "def get_completion(client, prompt, max_tokens=3000, model='claude-instant-1.1', temperature=0):\n", + "def get_completion(client, prompt, max_tokens=3000, model='claude-instant-1.2', temperature=0):\n", " return client.completions.create(\n", " prompt=prompt, max_tokens_to_sample=max_tokens, model=model, temperature=temperature\n", " ).completion\n", "\n", - "df['qas'] = df.prompt.apply(lambda prompt: get_completion(CLIENT, prompt))" + "async def process_case(limiter, client, prompt, results, output_col_name='completion'):\n", + "\n", + " async with limiter:\n", + " completion = await trio.to_thread.run_sync(get_completion, client, prompt)\n", + "\n", + " results.append({'prompt': prompt, output_col_name: completion})\n", + "\n", + " # if len(results) % 5 == 0:\n", + " # print(f\"{len(results)} test cases processed\") # Optional \"progress bar\"\n", + "\n", + "async def get_completions_parallel(client, prompts, output_col_name='completion'):\n", + " async with trio.open_nursery() as nursery:\n", + " limiter = trio.CapacityLimiter(5) # Set this to the maximum concurrency allowed on your API key, which may just be 1.\n", + " results = []\n", + " for prompt in prompts:\n", + " nursery.start_soon(process_case, limiter, CLIENT, prompt, results, output_col_name)\n", + " return results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "qas = await get_completions_parallel(CLIENT, df.prompt.values, output_col_name='qas')\n", + "df = df.merge(pd.DataFrame(qas), on='prompt')" ] }, { @@ -367,7 +399,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 
8, "metadata": {}, "outputs": [], "source": [ @@ -377,7 +409,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -411,19 +443,19 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We started out with 86 sections after devoting 2 of the original 88 to examples. That should yield 86 * 5 = 430 questions, but there were 12 formatting errors, so we end up with 418 after exploding." + "We started out with 86 sections after devoting 2 of the original 88 to examples, yielding 86 * 5 = 430 questions." ] }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "418\n" + "430\n" ] } ], @@ -449,33 +481,48 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Who can grant permission for vessels to enter or operate within the safety zone?\n", - "The Captain of the Port Detroit or his designated representative.\n", - "['The on-scene Coast Guard patrol commander.', 'The Chesterfield Harbor Master.', 'Vessels do not need permission, but must comply with directions from the Coast Guard.']\n", - "How many feet is 200 yards?\n", - "600 feet\n", - "['800 feet', '1200 feet', '1700 feet']\n", - "What agency published the final rule on July 3, 2023?\n", - "Environmental Protection Agency\n", - "['Department of Transportation', 'Department of Energy', 'Federal Aviation Administration']\n", - "What is the EPA taking final action to approve?\n", - "The \"Coso Junction PM10 Planning Area Second 10-Year Maintenance Plan\"\n", - "['New regulations for particulate matter emissions from off-road vehicles', \"A revision to California's State Implementation Plan for ozone standards\", 'Funding for air quality monitoring equipment in the Coso Junction Planning Area']\n", - "What does the EPA find the contribution of motor vehicle emissions to be in the Coso Junction Planning Area?\n", - "Insignificant\n", - "['Significant', 'Undetermined', 'Substantial']\n" + "What is the URL provided in the passage where users can access the FCC's fee payment module through the FRN access page after the closure of Lockbox 979097?\n", + "https://apps.fcc.gov/cores/paymentFrnLogin.do\n", + "['https://www.fcc.gov/licensing-databases/fees', 'https://www.fcc.gov/wireless-fees', 'https://www.fcc.gov/electronic-filing']\n", + "What is the total number of parts of the Code of Federal Regulations that are amended by the changes outlined in the passage?\n", + "3\n", + "['5', '2', '1']\n", + "What agency published this document announcing ten inseason actions for the 2023-2024 ocean salmon fishing season?\n", + "National Marine Fisheries Service (NMFS), National Oceanic and Atmospheric Administration (NOAA), Commerce.\n", + "['Federal Aviation Administration', 'Department of Defense', 'National Oceanic and Atmospheric Administration']\n", + "What is the effective date range for Inseason Action #1 described in this document?\n", + "May 17, 2023, at 12:01 a.m. 
until superseded.\n", + "['May 15, 2023 to June 29, 2023', 'June 8, 2023 to July 12, 2023', 'May 11, 2023 to August 14, 2023']\n", + "What is given as the reason for Inseason Action #5 increasing the landing and possession limit for the area between Leadbetter Point and Cape Falcon?\n", + "Lower than anticipated catch rates in the area\n", + "['New abundance estimates indicated higher quotas were sustainable', 'Concerns about socioeconomic impacts on local fishing communities', 'A court order requiring more liberal bag limits']\n", + "What is the contact provided in the document for further information about the inseason actions described in it?\n", + "Shannon Penna, 562–980–4239, Shannon.Penna@noaa.gov\n", + "['Todd Richardson, 202-402-5706', 'Daniela Cruzado, 202-295-7589', 'Nanette Smith, Team Lead, NASA Directives and Regulations']\n", + "How many inseason actions for the fisheries are described in this document?\n", + "10\n", + "['5', '15', '8']\n", + "What is the total allowable catch (ITAC) for blackspotted and rougheye rockfish set for the Central Aleutian and Western Aleutian districts (CAI/WAI) of the Bering Sea and Aleutian Islands management area (BSAI) in 2023 according to this document?\n", + "141 metric tons\n", + "['187 metric tons', '101 metric tons']\n", + "According to this document, as of what date had the 2023 blackspotted and rougheye rockfish ITAC for the CAI/WAI been reached?\n", + "July 7, 2023\n", + "['June 30, 2023', 'August 15, 2023']\n", + "What agency published this temporary rule concerning blackspotted and rougheye rockfish in the Federal Register?\n", + "National Marine Fisheries Service (NMFS), National Oceanic and Atmospheric Administration (NOAA), Commerce\n", + "['Environmental Protection Agency', 'Department of the Interior']\n" ] } ], "source": [ - "for i in range(18, 23):\n", + "for i in range(28, 38):\n", " for c in ['question', 'right_answer', 'wrong_answers_q']:\n", " print(qa_df.iloc[i][c])" ] @@ -484,7 +531,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The 200-yard question doesn't require any particular context about the document to answer. Claude should just know it off the top of it's head. We should take some steps to validate that there are relatively few questions like this and the test is still \"hard enough\". The other questions look good." + "Looks reasonable!" ] }, { @@ -503,7 +550,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -533,7 +580,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -567,21 +614,32 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "qa_df['qa_with_right_chunk_prompt'] = qa_df.apply(lambda row: mc_answer_one_chunk_prompt.format(\n", " chunk=row['chunk'], question=row['question'], answers=row['randomized_answers']),\n", " axis=1\n", - ") # Populate prompt column\n", - "\n", - "qa_df['qa_answer_right_chunk'] = qa_df.apply(\n", - " lambda row: get_completion(\n", - " CLIENT, mc_answer_one_chunk_prompt.format(chunk=row['chunk'], question=row['question'], answers=row['randomized_answers'])\n", - " ),\n", - " axis=1\n", - ") # Call Claude on each chunk + question + answer group pair." 
+ ") # Populate prompt column" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "qa_answer_right_chunk = await get_completions_parallel(CLIENT, qa_df['qa_with_right_chunk_prompt'].values, output_col_name='qa_answer_right_chunk')" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "qa_df = qa_df.merge(pd.DataFrame(qa_answer_right_chunk), left_on='qa_with_right_chunk_prompt', right_on='prompt', suffixes=['', '_x'])" ] }, { @@ -593,7 +651,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -611,27 +669,19 @@ }, { "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [], - "source": [ - "qa_df['qa_answer_right_chunk'] = [extract_between_tags('Answer', sample)[0][0] for sample in qa_df['qa_answer_right_chunk'].values]" - ] - }, - { - "cell_type": "code", - "execution_count": 32, + "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Results: 371 47\n" + "Results: 389 41\n" ] } ], "source": [ + "qa_df['qa_answer_right_chunk'] = [extract_between_tags('Answer', sample)[0][0] for sample in qa_df['qa_answer_right_chunk'].values]\n", "print_results(qa_df, qa_df['qa_answer_right_chunk'])" ] }, @@ -639,14 +689,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "It got most of them right. That's good as it shows the questions are mostly well-formed. It's interesting to look at some examples Claude got wrong; they tend to be about counting the number of items in long lists, which is not one of Claude's specialties. Better question prompting could remove these.\n", - "\n", - "Now, we'll see how Claude does when, instead of giving Claude the chunk with the answer, we give it some random other chunk. Poor Claude!" + "It got 90% of them right. Now, we'll see how Claude does when, instead of giving Claude the chunk with the answer, we give it some random other chunk. Poor Claude!" 
] }, { "cell_type": "code", - "execution_count": 31, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -657,27 +705,44 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ - "qa_df['qa_answer_shift_chunk'] = qa_df.apply(\n", - " lambda row: get_completion(\n", - " CLIENT, mc_answer_one_chunk_prompt.format(chunk=row['shifted_chunk'], question=row['question'], answers=row['randomized_answers'])\n", - " ),\n", - " axis=1)" + "qa_df['qa_with_shift_chunk_prompt'] = qa_df.apply(\n", + " lambda row: mc_answer_one_chunk_prompt.format(chunk=row['shifted_chunk'], question=row['question'], answers=row['randomized_answers']),\n", + " axis=1\n", + ")" ] }, { "cell_type": "code", - "execution_count": 33, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "qa_answer_shift_chunk = await get_completions_parallel(CLIENT, qa_df['qa_with_shift_chunk_prompt'].values, output_col_name='qa_answer_shift_chunk')" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "qa_df = qa_df.merge(pd.DataFrame(qa_answer_shift_chunk), left_on='qa_with_shift_chunk_prompt', right_on='prompt', suffixes=['', '_x'])" + ] + }, + { + "cell_type": "code", + "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Results: 148 270 \n" + "Results: 148 282\n" ] } ], @@ -690,21 +755,21 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "By sheer chance Claude would be expected to get 25% right. In practice, Claude got 35%. Just as smart humans like us have the ability to guess above chance on a standardized test, so does Claude. Still a far cry from Claude's accuracy when given the right chunk, so the experiment is meaningful. We'll filter out the questions where Claude didn't get the correct answer even with the relevant chunk." + "By sheer chance Claude would be expected to get 25% right. In practice, Claude got 34% right. Just as smart humans like us have the ability to guess above chance on a standardized test, so does Claude. Still a far cry from Claude's accuracy when given the right chunk, so the experiment is meaningful. We'll filter out the questions where Claude didn't get the correct answer even with the relevant chunk, as those are \"too difficult\" for testing the impact of long context." ] }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "371" + "389" ] }, - "execution_count": 41, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -730,7 +795,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ @@ -757,7 +822,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 28, "metadata": {}, "outputs": [], "source": [ @@ -768,7 +833,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 29, "metadata": {}, "outputs": [], "source": [ @@ -794,7 +859,147 @@ "source": [ "Now we'll do another round of sampling for beginning, middle, and end. \n", "\n", - "*Note: Each of these cells takes a while to run.* Parallelization can speed things up. But if you're just following along for fun, you probably want to run this only on a few rows of qa_df." 
+ "*Note: Each of these cells takes a while to run.* If you're just following along for fun, you probably want to run this only on a few rows of qa_df." + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Results: 375 14\n" + ] + } + ], + "source": [ + "qa_answers_long_ctx_end = await get_completions_parallel(CLIENT, qa_df.qa_long_ctx_prompt_end.values, output_col_name='qa_answers_long_ctx_end')\n", + "qa_df = qa_df.merge(pd.DataFrame(qa_answers_long_ctx_end), left_on='qa_long_ctx_prompt_end', right_on='prompt', suffixes=['', '_x'], how='left')\n", + "qa_df['qa_answers_long_ctx_end'] = [extract_between_tags('Answer', sample)[0][0] for sample in qa_df['qa_answers_long_ctx_end'].values]\n", + "print_results(qa_df, qa_df['qa_answers_long_ctx_end'].values)" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Results: 333 56\n" + ] + } + ], + "source": [ + "qa_answers_long_ctx_middle = await get_completions_parallel(CLIENT, qa_df.qa_long_ctx_prompt_middle.values, output_col_name='qa_answers_long_ctx_middle')\n", + "qa_df = qa_df.merge(pd.DataFrame(qa_answers_long_ctx_middle), left_on='qa_long_ctx_prompt_middle', right_on='prompt', suffixes=['', '_x'], how='left')\n", + "qa_df['qa_answers_long_ctx_middle'] = [extract_between_tags('Answer', sample)[0][0] for sample in qa_df['qa_answers_long_ctx_middle'].values]\n", + "print_results(qa_df, qa_df['qa_answers_long_ctx_middle'].values)" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Results: 272 117\n" + ] + } + ], + "source": [ + "qa_answers_long_ctx_beginning = await get_completions_parallel(CLIENT, qa_df.qa_long_ctx_prompt_beginning.values, output_col_name='qa_answers_long_ctx_beginning')\n", + "qa_df = qa_df.merge(pd.DataFrame(qa_answers_long_ctx_beginning), left_on='qa_long_ctx_prompt_beginning', right_on='prompt', suffixes=['', '_x'], how='left')\n", + "qa_df['qa_answers_long_ctx_beginning'] = [extract_between_tags('Answer', sample)[0][0] for sample in qa_df['qa_answers_long_ctx_beginning'].values]\n", + "print_results(qa_df, qa_df['qa_answers_long_ctx_beginning'].values)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we'll try adding some examples of multiple-choice question-answering to the doc, using some made-up examples." + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "mc_answer_lc_with_nongov_examples_prompt = \"\"\"\\n\\nHuman: Please read the following government record closely and then answer the multiple choice question below.\n", + "\n", + "{chunk}\n", + "\n", + "Based on the government record above, select the correct answer to the question from the list below and write the corresponding letter (A, B, C, or D) in tags.\n", + "First, here are two example questions.\n", + "\n", + "Who was the first president of the United States?\n", + "\n", + "\n", + "A. Thomas Jefferson\n", + "B. George Washington\n", + "C. Abraham Lincoln\n", + "D. John Adams\n", + "\n", + "Here, the correct answer is:\n", + "\n", + "B. George Washington\n", + "\n", + "\n", + "What is the boiling temperature of water, in degrees Fahrenheit?\n", + "\n", + "\n", + "A. 200\n", + "B. 100\n", + "C. 287\n", + "D. 
212\n", + "\n", + "Here, the correct answer is:\n", + "\n", + "D. 212\n", + "\n", + "Now please answer this question\n", + "\n", + "{question}\n", + "\n", + "Based on the government record above, select the correct answer to the question from the list below and write the corresponding letter (A, B, C, or D) in tags.\n", + "\n", + "{answers}\n", + "\n", + "\n", + "Assistant: Based on the government record provided above, the correct answer to the question is:\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "qa_df['qa_long_ctx_prompt_end_nongov_examples'] = qa_df.apply(lambda row: mc_answer_lc_with_nongov_examples_prompt.format(\n", + " chunk=row['long_context_end'], question=row['question'], answers=row['randomized_answers']),\n", + " axis=1\n", + ")\n", + "\n", + "qa_df['qa_long_ctx_prompt_middle_nongov_examples'] = qa_df.apply(lambda row: mc_answer_lc_with_nongov_examples_prompt.format(\n", + " chunk=row['long_context_middle'], question=row['question'], answers=row['randomized_answers']),\n", + " axis=1\n", + ")\n", + "\n", + "qa_df['qa_long_ctx_prompt_beginning_nongov_examples'] = qa_df.apply(lambda row: mc_answer_lc_with_nongov_examples_prompt.format(\n", + " chunk=row['long_context_beginning'], question=row['question'], answers=row['randomized_answers']),\n", + " axis=1\n", + ")" ] }, { @@ -803,7 +1008,29 @@ "metadata": {}, "outputs": [], "source": [ - "qa_df['qa_answers_long_ctx_end'] = qa_df.qa_long_ctx_prompt_end.apply(lambda prompt: get_completion(CLIENT, prompt),)" + "qa_answers_long_ctx_end_nongov_examples = await get_completions_parallel(\n", + " CLIENT, qa_df['qa_long_ctx_prompt_end_nongov_examples'].values, output_col_name='qa_long_ctx_answers_end_nongov_examples')\n", + "qa_df = qa_df.merge(pd.DataFrame(qa_answers_long_ctx_end_nongov_examples), left_on='qa_long_ctx_prompt_end_nongov_examples', right_on='prompt', suffixes=['', '_n'], how='left')" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Results: 375 14\n" + ] + } + ], + "source": [ + "qa_df['qa_answers_long_ctx_end_nongov_examples_parsed'] = [\n", + " extract_between_tags('Answer', sample)[0][0] \n", + " for sample in qa_df['qa_long_ctx_answers_end_nongov_examples'].values]\n", + "print_results(qa_df, qa_df['qa_answers_long_ctx_end_nongov_examples_parsed'].values)" ] }, { @@ -812,7 +1039,29 @@ "metadata": {}, "outputs": [], "source": [ - "qa_df['qa_answers_long_ctx_middle'] = qa_df.qa_long_ctx_prompt_middle.apply(lambda prompt: get_completion(CLIENT, prompt),)" + "qa_answers_long_ctx_middle_nongov_examples = await get_completions_parallel(\n", + " CLIENT, qa_df['qa_long_ctx_prompt_middle_nongov_examples'].values, output_col_name='qa_answers_long_ctx_middle_nongov_examples')\n", + "qa_df = qa_df.merge(pd.DataFrame(qa_answers_long_ctx_middle_nongov_examples), left_on=f'qa_long_ctx_prompt_middle_nongov_examples', right_on='prompt', suffixes=['_m', '_n'], how='left')" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Results: 338 51\n" + ] + } + ], + "source": [ + "qa_df['qa_answers_long_ctx_middle_nongov_examples_parsed'] = [\n", + " extract_between_tags('Answer', sample)[0][0] \n", + " for sample in qa_df['qa_answers_long_ctx_middle_nongov_examples'].values]\n", + "print_results(qa_df, 
qa_df['qa_answers_long_ctx_middle_nongov_examples_parsed'].values)" ] }, { @@ -821,53 +1070,331 @@ "metadata": {}, "outputs": [], "source": [ - "qa_df['qa_answers_long_ctx_beginning'] = qa_df.qa_long_ctx_prompt_beginning.apply(lambda prompt: get_completion(CLIENT, prompt),)" + "qa_answers_long_ctx_beginning_nongov_examples = await get_completions_parallel(\n", + " CLIENT, qa_df['qa_long_ctx_prompt_beginning_nongov_examples'].values, output_col_name='qa_answers_long_ctx_beginning_nongov_examples')\n", + "qa_df = qa_df.merge(pd.DataFrame(qa_answers_long_ctx_beginning_nongov_examples), left_on=f'qa_long_ctx_prompt_beginning_nongov_examples', right_on='prompt', suffixes=['_m', '_n'], how='left')" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Results: 281 108\n" + ] + } + ], + "source": [ + "qa_df['qa_answers_long_ctx_beginning_nongov_examples_parsed'] = [\n", + " extract_between_tags('Answer', sample)[0][0] \n", + " for sample in qa_df['qa_answers_long_ctx_beginning_nongov_examples'].values]\n", + "print_results(qa_df, qa_df['qa_answers_long_ctx_beginning_nongov_examples_parsed'].values)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "And voila, the results!" + "The results are a little better for the middle and beginning tests, equal for the end test. Can we do better by adding examples that are more germane to the task? \n", + "\n", + "The procedure for generating examples is as follows. For each question, find its associated chunk, then choose random QAs from other chunks that aren't that chunk." ] }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 44, + "metadata": {}, + "outputs": [], + "source": [ + "def gen_mc_answer_lc_with_examples_prompt(num_examples): \n", + " examples_section = \"some example questions.\"\n", + " for i in range(num_examples):\n", + " examples_section += \"\"\"\n", + "\n", + "{sample_question\"\"\" + str(i+1) + \"\"\"}\n", + "\n", + "\n", + "{sample_answers\"\"\" + str(i+1) + \"\"\"}\n", + "\n", + "Here, the correct answer is:\n", + "\n", + "{correct_answer\"\"\" + str(i+1) + \"\"\"}\n", + "\"\"\"\n", + " return \"\"\"\\n\\nHuman: Please read the following government record closely and then answer the multiple choice question below.\n", + "\n", + "{chunk}\n", + "\n", + "Based on the government record above, select the correct answer to the question from the list below and write the corresponding letter (A, B, C, or D) in tags.\n", + "First, here are \"\"\" + examples_section + \"\"\"\n", + "Now please answer this question\n", + "\n", + "{question}\n", + "\n", + "Based on the government record above, select the correct answer to the question from the list below and write the corresponding letter (A, B, C, or D) in tags.\n", + "\n", + "{answers}\n", + "\n", + "\n", + "Assistant: Based on the government record provided above, the correct answer to the question is:\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [], + "source": [ + "def grab_example_qas(long_context_row, long_context_col, qa_df, num_examples=2):\n", + " examples = []\n", + " for i, row in qa_df.sample(frac=1).iterrows(): # Randomize order of questions\n", + " if row['chunk'] in long_context_row[long_context_col] and row['chunk'] != long_context_row.chunk:\n", + " examples.append({\n", + " 'question': row.question, 'answers': row.randomized_answers, \n", + " 'correct_answer': [a for a in 
row.randomized_answers if row.right_answer in a][0][0]})\n", + " if len(examples) >= num_examples:\n", + " break\n", + " examples_numbered = {}\n", + " for i in range(num_examples):\n", + " examples_numbered['sample_question' + str(i+1)] = examples[i]['question']\n", + " examples_numbered['sample_answers' + str(i+1)] = examples[i]['answers']\n", + " examples_numbered['correct_answer' + str(i+1)] = examples[i]['correct_answer']\n", + " return examples_numbered" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [], + "source": [ + "def format_for_long_ctx_with_examples(row, chunk_col, long_context_col, qa_df, num_examples=2):\n", + " example_qas = grab_example_qas(long_context_row=row, long_context_col=long_context_col, qa_df=qa_df, num_examples=num_examples)\n", + " format_args = {}\n", + " for i in range(1, num_examples+1):\n", + " format_args['sample_question'+str(i)] = example_qas['sample_question'+str(i)] \n", + " format_args['sample_answers'+str(i)] = example_qas['sample_answers'+str(i)]\n", + " format_args['correct_answer'+str(i)] = example_qas['correct_answer'+str(i)]\n", + " return gen_mc_answer_lc_with_examples_prompt(num_examples).format(\n", + " chunk=row[chunk_col], question=row['question'], answers=row['randomized_answers'],\n", + " **format_args\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we'll experiment with just 2 examples." + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [], + "source": [ + "qa_df['long_ctx_with_examples_end_prompt'] = qa_df.apply(\n", + " lambda row: format_for_long_ctx_with_examples(row, 'long_context_end', 'qa_long_ctx_prompt_end', qa_df, num_examples=2), axis=1)\n", + "\n", + "qa_df['long_ctx_with_examples_middle_prompt'] = qa_df.apply(\n", + " lambda row: format_for_long_ctx_with_examples(row, 'long_context_middle', 'qa_long_ctx_prompt_middle', qa_df, num_examples=2), axis=1)\n", + "\n", + "qa_df['long_ctx_with_examples_beginning_prompt'] = qa_df.apply(\n", + " lambda row: format_for_long_ctx_with_examples(row, 'long_context_beginning', 'qa_long_ctx_prompt_beginning', qa_df, num_examples=2), axis=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "qa_answers_long_ctx_examples_end = await get_completions_parallel(CLIENT, qa_df.long_ctx_with_examples_end_prompt.values, output_col_name='qa_answers_long_ctx_examples_end')\n", + "qa_df = qa_df.merge(pd.DataFrame(qa_answers_long_ctx_examples_end), left_on='long_ctx_with_examples_end_prompt', right_on='prompt', suffixes=['_a', '_b'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "qa_answers_long_ctx_examples_middle = await get_completions_parallel(CLIENT, qa_df.long_ctx_with_examples_middle_prompt.values, output_col_name='qa_answers_long_ctx_examples_middle')\n", + "qa_df = qa_df.merge(pd.DataFrame(qa_answers_long_ctx_examples_middle), left_on='long_ctx_with_examples_middle_prompt', right_on='prompt', suffixes=['_c', '_d'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "qa_answers_long_ctx_examples_beginning = await get_completions_parallel(CLIENT, qa_df.long_ctx_with_examples_beginning_prompt.values, output_col_name='qa_answers_long_ctx_examples_beginning')\n", + "qa_df = qa_df.merge(pd.DataFrame(qa_answers_long_ctx_examples_beginning), 
left_on='long_ctx_with_examples_beginning_prompt', right_on='prompt', suffixes=['_e', '_f'])" + ] + }, + { + "cell_type": "code", + "execution_count": 52, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Results: 362 9\n", - "Results: 340 31\n", - "Results: 334 37\n" + "Results: 374 15\n", + "Results: 357 32\n", + "Results: 302 87\n" ] } ], "source": [ - "qa_df['qa_answers_long_ctx_end'] = [extract_between_tags('Answer', sample)[0][0] for sample in qa_df['qa_answers_long_ctx_end'].values]\n", - "print_results(qa_df, qa_df['qa_answers_long_ctx_end'].values)\n", + "qa_df['qa_answers_long_ctx_examples_end_parsed'] = [extract_between_tags('Answer', sample)[0][0] for sample in qa_df['qa_answers_long_ctx_examples_end'].values]\n", + "print_results(qa_df, qa_df['qa_answers_long_ctx_examples_end_parsed'].values)\n", "\n", - "qa_df['qa_answers_long_ctx_middle'] = [extract_between_tags('Answer', sample)[0][0] for sample in qa_df['qa_answers_long_ctx_middle'].values]\n", - "print_results(qa_df, qa_df['qa_answers_long_ctx_middle'].values)\n", + "qa_df['qa_answers_long_ctx_examples_middle_parsed'] = [extract_between_tags('Answer', sample)[0][0] for sample in qa_df['qa_answers_long_ctx_examples_middle'].values]\n", + "print_results(qa_df, qa_df['qa_answers_long_ctx_examples_middle_parsed'].values)\n", "\n", - "qa_df['qa_answers_long_ctx_beginning'] = [extract_between_tags('Answer', sample)[0][0] for sample in qa_df['qa_answers_long_ctx_beginning'].values]\n", - "print_results(qa_df, qa_df['qa_answers_long_ctx_beginning'].values)" + "qa_df['qa_answers_long_ctx_examples_beginning_parsed'] = [extract_between_tags('Answer', sample)[0][0] for sample in qa_df['qa_answers_long_ctx_examples_beginning'].values]\n", + "print_results(qa_df, qa_df['qa_answers_long_ctx_examples_beginning_parsed'].values)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Claude got 98% accuracy when the relevant chunk was at the very end, 92% with it in the middle, and 91% with it at the beginning. (This monotonically decreasing trend as the info gets further from the question is an interesting divergence from the results of the [Lost in the Middle paper](https://arxiv.org/abs/2307.03172).)" + "Definitely better! What if we increase the number of examples to 5?" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [], + "source": [ + "num_examples = 5\n", + "qa_df[f'long_ctx_with_examples_end_prompt_{num_examples}'] = qa_df.apply(\n", + " lambda row: format_for_long_ctx_with_examples(row, 'long_context_end', 'qa_long_ctx_prompt_end', qa_df, num_examples=num_examples), axis=1)\n", + "\n", + "qa_df[f'long_ctx_with_examples_middle_prompt_{num_examples}'] = qa_df.apply(\n", + " lambda row: format_for_long_ctx_with_examples(row, 'long_context_middle', 'qa_long_ctx_prompt_middle', qa_df, num_examples=num_examples), axis=1)\n", + "\n", + "qa_df[f'long_ctx_with_examples_beginning_prompt_{num_examples}'] = qa_df.apply(\n", + " lambda row: format_for_long_ctx_with_examples(row, 'long_context_beginning', 'qa_long_ctx_prompt_beginning', qa_df, num_examples=num_examples), axis=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "qa_answers_long_ctx_examples_beginning_n = await get_completions_parallel(\n", + " CLIENT, qa_df[f'long_ctx_with_examples_beginning_prompt_{num_examples}'].values, output_col_name=f'qa_answers_long_ctx_examples_beginning_{num_examples}')\n", + "qa_df = qa_df.merge(pd.DataFrame(qa_answers_long_ctx_examples_beginning_n), left_on=f'long_ctx_with_examples_beginning_prompt_{num_examples}', right_on='prompt', suffixes=['_g', '_h'])" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Results: 312 77\n" + ] + } + ], + "source": [ + "qa_df[f'qa_answers_long_ctx_examples_beginning_{num_examples}_parsed'] = [\n", + " extract_between_tags('Answer', sample)[0][0] \n", + " for sample in qa_df[f'qa_answers_long_ctx_examples_beginning_{num_examples}'].values]\n", + "print_results(qa_df, qa_df[f'qa_answers_long_ctx_examples_beginning_{num_examples}_parsed'].values)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "qa_answers_long_ctx_examples_middle_n = await get_completions_parallel(\n", + " CLIENT, qa_df[f'long_ctx_with_examples_middle_prompt_{num_examples}'].values, output_col_name=f'qa_answers_long_ctx_examples_middle_{num_examples}')\n", + "qa_df = qa_df.merge(pd.DataFrame(qa_answers_long_ctx_examples_middle_n), left_on=f'long_ctx_with_examples_middle_prompt_{num_examples}', right_on='prompt', suffixes=['_i', '_j'])" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Results: 362 27\n" + ] + } + ], + "source": [ + "qa_df[f'qa_answers_long_ctx_examples_middle_{num_examples}_parsed'] = [\n", + " extract_between_tags('Answer', sample)[0][0] \n", + " for sample in qa_df[f'qa_answers_long_ctx_examples_middle_{num_examples}'].values]\n", + "print_results(qa_df, qa_df[f'qa_answers_long_ctx_examples_middle_{num_examples}_parsed'].values)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "qa_answers_long_ctx_examples_end_n = await get_completions_parallel(\n", + " CLIENT, qa_df[f'long_ctx_with_examples_end_prompt_{num_examples}'].values, output_col_name=f'qa_answers_long_ctx_examples_end_{num_examples}')\n", + "qa_df = qa_df.merge(pd.DataFrame(qa_answers_long_ctx_examples_end_n), left_on=f'long_ctx_with_examples_end_prompt_{num_examples}', right_on='prompt', suffixes=['_k', '_l'])" + ] + }, + { + "cell_type": 
"code", + "execution_count": 60, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Results: 378 11\n" + ] + } + ], + "source": [ + "qa_df[f'qa_answers_long_ctx_examples_end_{num_examples}_parsed'] = [\n", + " extract_between_tags('Answer', sample)[0][0] \n", + " for sample in qa_df[f'qa_answers_long_ctx_examples_end_{num_examples}'].values]\n", + "print_results(qa_df, qa_df[f'qa_answers_long_ctx_examples_end_{num_examples}_parsed'].values)" ] }, { "cell_type": "markdown", "metadata": {}, - "source": [] + "source": [ + "Better still! Summarizing the results:\n", + "\n", + "| | | | |\n", + "| ------------------------- | ------------: | ---------: | ------: |\n", + "| Model: Claude Instant 1.2 | **Beginning** | **Middle** | **End** |\n", + "| **Just Ask** | 0.70 | 0.86 | 0.96 |\n", + "| **Generic Examples** | 0.72 | 0.87 | 0.96 |\n", + "| **2 Contextual Examples** | 0.78 | 0.92 | 0.96 |\n", + "| **5 Contextual Examples** | 0.80 | 0.93 | 0.97 |\n" + ] } ], "metadata": {