RAFT Recovery Mode for interruptions (ShishirPatil#410)
Implemented a "safe"/recovery mode that periodically saves chunks into
"checkpoint" datasets while also keeping track of the current chunk
number. In the case of an interruption or crash, the script resumes at
the saved chunk number. After all chunks have been processed, all
checkpoint datasets are concatenated and saved as one final dataset.
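In outline, the recovery flow looks like this (a minimal, self-contained sketch; `process_chunk`, the `chunks` list, and the dataset-saving steps are illustrative stand-ins for the real per-chunk generation in raft.py below):

```python
import os

N = 15  # save a checkpoint every N chunks
CHECKPOINT = "checkpoint.txt"
chunks = [f"chunk {i}" for i in range(40)]  # stand-in for real document chunks

def process_chunk(chunk: str) -> None:
    """Stand-in for generating Q/A datapoints from one chunk."""

# On startup, resume from the last completed checkpoint boundary;
# chunks processed after the last dataset save are simply redone.
start = 0
if os.path.exists(CHECKPOINT):
    with open(CHECKPOINT) as f:
        start = int(f.read())

for i in range((start // N) * N, len(chunks)):
    with open(CHECKPOINT, "w") as f:
        f.write(str(i))  # record progress before processing chunk i
    process_chunk(chunks[i])
    if (i + 1) % N == 0:
        ...  # flush accumulated rows to a checkpoint dataset on disk

# Afterwards: concatenate the checkpoint datasets into the final dataset
# and delete CHECKPOINT plus the intermediate checkpoint data.
```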

Added an argument allowing the user to choose whether to run RAFT in safe
or fast mode (defaults to safe).

Close ShishirPatil#394

---------

Co-authored-by: Kaihao Wen <[email protected]>
kaiwen129 and Kaihao Wen authored May 4, 2024
1 parent 4e8dac1 commit 624d371
Showing 3 changed files with 64 additions and 10 deletions.
5 changes: 3 additions & 2 deletions raft/README.md
@@ -25,13 +25,14 @@ Arguments:
- `--openai_key` - Your OpenAI key used to make queries to GPT-3.5 or GPT-4
- `--embedding-model` - The embedding model to use to encode document chunks. Defaults to `text-embedding-ada-002`.
- `--completion-model` - The model to use to generate questions and answers. Defaults to `gpt-4`.
- `--fast` - Fast mode flag. By default, this flag is omitted and the script runs in safe mode, saving checkpoint datasets so it can recover and continue where it left off after an interruption or crash. Include this flag to run RAFT without recovery (see the example below).
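
For example, to skip checkpointing and run in fast mode, append `--fast` to the usual invocation (arguments as documented above):

```bash
python3 raft.py --datapath PATH_TO_DATA --output OUTPUT_PATH --distractors 3 --doctype pdf --chunk_size 512 --questions 5 --openai_key YOUR_OPENAI_KEY --fast
```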


## Usage with OpenAI API

Run the following command with your desired arguments to generate the dataset.
```bash
python3 raft.py --datapath PATH_TO_DATA --output OUTPUT_PATH --distractors 3 --doctype pdf --chunk_size 512 --questions 5 --openai_key YOUR_OPENAI_KEY
python3 raft.py --datapath PATH_TO_DATA --output OUTPUT_PATH --distractors 3 --p 1.0 --doctype pdf --chunk_size 512 --questions 5 --openai_key YOUR_OPENAI_KEY
```

**Note**: As an alternative to passing the OpenAI key with the `--openai_key` argument, you can also store the standard OpenAI environment variables in a file called `.env`, like so. All standard OpenAI env variables are supported.
@@ -49,7 +50,7 @@ OPENAI_API_KEY=<replace_me>

## Usage with Azure OpenAI API

Create a file `.env` like so. All standard Azure OpenAI environement variables are supported.
Create a file `.env` like so. All standard Azure OpenAI environment variables are supported.

```
# Azure OpenAI API
69 changes: 61 additions & 8 deletions raft/raft.py
Expand Up @@ -5,11 +5,13 @@
from typing import Literal, Any
import argparse
from openai import OpenAI
import datasets
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer
import json
import PyPDF2
import random
import os, shutil
from langchain_experimental.text_splitter import SemanticChunker
from langchain_openai.embeddings import OpenAIEmbeddings
from client_utils import build_openai_client, build_langchain_embeddings
@@ -21,6 +23,9 @@

DocType = Literal["api", "pdf", "json", "txt"]

# Every N chunks, save checkpoint
N = 15

def get_args() -> argparse.Namespace:
    """
    Parses and returns the arguments specified by the user's command
@@ -35,8 +40,9 @@ def get_args() -> argparse.Namespace:
parser.add_argument("--chunk_size", type=int, default=512, help="The size of each chunk in number of tokens")
parser.add_argument("--doctype", type=str, default="pdf", help="The type of the document, must be one of the accepted doctypes", choices=["pdf", "txt", "json", "api"])
parser.add_argument("--openai_key", type=str, default=None, help="Your OpenAI key used to make queries to GPT-3.5 or GPT-4")
parser.add_argument("--embedding-model", type=str, default="text-embedding-ada-002", help="The embedding model to use to encode documents chunks (text-embedding-ada-002, ...)")
parser.add_argument("--completion-model", type=str, default="gpt-4", help="The model to use to generate questions and answers (gpt-3.5, gpt-4, ...)")
parser.add_argument("--embedding_model", type=str, default="text-embedding-ada-002", help="The embedding model to use to encode documents chunks (text-embedding-ada-002, ...)")
parser.add_argument("--completion_model", type=str, default="gpt-4", help="The model to use to generate questions and answers (gpt-3.5, gpt-4, ...)")
parser.add_argument("--fast", action="store_true", help="Run the script in fast mode (no recovery implemented)")

args = parser.parse_args()
return args
@@ -273,6 +279,14 @@ def add_chunk_to_dataset(
    else:
        ds = ds.add_item(datapt)

def save_checkpoint(state, filename):
    """Record the index of the chunk currently being processed."""
    with open(filename, 'w') as f:
        f.write(str(state))

def load_checkpoint(filename):
    """Read back the saved chunk index as text; callers parse it (avoids eval on file contents)."""
    with open(filename, 'r') as f:
        return f.read()

def main():
    global ds

@@ -293,18 +307,57 @@ def main():
        ds = None

    num_chunks = len(chunks)
    for i, chunk in enumerate(chunks):
        perc = ceil(i / num_chunks * 100)
        with MDC(progress=f"{perc}%"):
            logger.info(f"Adding chunk {i}/{num_chunks}")
            add_chunk_to_dataset(client, chunks, chunk, args.doctype, args.questions, NUM_DISTRACT_DOCS, model=args.completion_model)

    if not args.fast:
        start = 0
        if os.path.exists("checkpoint.txt"):
            start = int(load_checkpoint("checkpoint.txt"))

        # Resume from the last completed checkpoint boundary; chunks
        # processed after the previous dataset save are regenerated.
        for i in range((start // N) * N, len(chunks)):
            chunk = chunks[i]
            save_checkpoint(i, "checkpoint.txt")

            perc = ceil(i / num_chunks * 100)
            with MDC(progress=f"{perc}%"):
                logger.info(f"Adding chunk {i}/{num_chunks}")
                add_chunk_to_dataset(client, chunks, chunk, args.doctype, args.questions, NUM_DISTRACT_DOCS, model=args.completion_model)

            # Every N chunks, flush the accumulated rows to a checkpoint dataset.
            if (i + 1) % N == 0:
                ds.save_to_disk(args.output + "-checkpoints-" + str(i))
                ds = None

        # Save whatever is left over after the last full checkpoint.
        if ds:
            ds.save_to_disk(args.output + "-checkpoints-last")

        # Concatenate every checkpoint dataset into a single dataset.
        ds_list = []
        for filename in os.listdir(os.path.dirname(args.output)):
            if "-checkpoints-" in filename:
                for f in os.listdir(os.path.dirname(args.output) + "/" + filename):
                    if f.endswith(".arrow"):
                        ds_list.append(Dataset.from_file(os.path.dirname(args.output) + "/" + filename + "/" + f))

        ds = datasets.concatenate_datasets(ds_list)
    else:
        for i, chunk in enumerate(chunks):
            perc = ceil(i / num_chunks * 100)
            with MDC(progress=f"{perc}%"):
                logger.info(f"Adding chunk {i}/{num_chunks}")
                add_chunk_to_dataset(client, chunks, chunk, args.doctype, args.questions, NUM_DISTRACT_DOCS, model=args.completion_model)

    # Save as .arrow format
    ds.save_to_disk(args.output)

    # Save as .jsonl format
    ds.to_json(args.output + ".jsonl")

    # The final dataset is saved, so remove the checkpoint file and the
    # intermediate checkpoint datasets.
    if not args.fast:
        os.remove("checkpoint.txt")
        for filename in os.listdir(os.path.dirname(args.output)):
            if "-checkpoints-" in filename:
                shutil.rmtree(os.path.dirname(args.output) + "/" + filename)

if __name__ == "__main__":
    with MDC(progress="0%"):
        main()
Binary file added raft/sample_data/UC_Berkeley_short.pdf
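For reference, recovery in the default safe mode needs no extra flags: rerunning the identical command resumes from `checkpoint.txt`. A hypothetical session (output path illustrative, using the sample PDF added in this commit):

```bash
# First run is interrupted partway through (Ctrl-C, crash, ...).
python3 raft.py --datapath raft/sample_data/UC_Berkeley_short.pdf --output ./out/ucb --doctype pdf

# Rerun the same command: the script reads checkpoint.txt, restarts at the
# last checkpoint boundary, and finally concatenates all checkpoint datasets.
python3 raft.py --datapath raft/sample_data/UC_Berkeley_short.pdf --output ./out/ucb --doctype pdf
```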
