Add threading to save memory
VikParuchuri committed Sep 28, 2023
1 parent f3e9726 commit 060132c
Showing 9 changed files with 72 additions and 12 deletions.
app/components/parser.py (1 addition, 1 deletion)
@@ -20,7 +20,7 @@ def parse_single_component(component_type: str, markdown: str, key: str = None):
}
match component_type:
case ComponentNames.exercise:
exercise_data = re.split(r"\nAnswer\n|\nAnswer:\n|\nAnswer: ", markdown)
exercise_data = re.split(r"\nAnswer\n|\nAnswer:\n|\nAnswer: |\nSolution: |\nSolution:\n|\nSolution\n", markdown)
match len(exercise_data):
case 1:
instructions, answer = exercise_data[0], None
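For context, a minimal standalone sketch (not repo code) of what the widened pattern does: re.split with this alternation now breaks an exercise into instructions and answer on either an "Answer" or a "Solution" marker.

```python
import re

# Same alternation as the new pattern above.
pattern = r"\nAnswer\n|\nAnswer:\n|\nAnswer: |\nSolution: |\nSolution:\n|\nSolution\n"

old_style = "Do the task.\nAnswer\nThe result."
new_style = "Do the task.\nSolution: The result."

print(re.split(pattern, old_style))  # ['Do the task.', 'The result.']
print(re.split(pattern, new_style))  # ['Do the task.', 'The result.']
```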
app/lesson/parser.py (1 addition, 1 deletion)
@@ -61,7 +61,7 @@ def render_components_to_markdown(components: List[AllLessonComponentData]) -> s
tuples.append(
(
component.type,
f"Instructions\n\n{component.instructions}\n\nAnswer\n\n{component.solution}",
f"Instructions:\n\n{component.instructions}\n\nSolution:\n\n{component.solution}",
)
)
case _:
app/llm/examples/lesson.json (2 additions, 2 deletions)
@@ -18,7 +18,7 @@
],
[
"exercise",
"\nInstructions\n\nUsing the in operator, check whether the following values exist as dictionary keys in the content_ratings dictionary from earlier:\n- The string '9+'\n- The integer `987`\n\nAnswer\n\n```python\nis_in_1 = '9+' in content_ratings\nis_in_2 = 987 in content_ratings\n```\n\n"
"\nInstructions:\n\nUsing the in operator, check whether the following values exist as dictionary keys in the content_ratings dictionary from earlier:\n- The string '9+'\n- The integer `987`\n\nSolution:\n\n```python\nis_in_1 = '9+' in content_ratings\nis_in_2 = 987 in content_ratings\n```\n\n"
]
]
},
@@ -41,7 +41,7 @@
],
[
"exercise",
"\nInstructions\n\nWhy do you think the colonists needed help from Native Americans?.\n\nAnswer\n\nSettlers might have needed to get help from native tribes because they were facing starvation and disease. Coming from settled Europe, they lacked key survival skills, and resupply over the ocean was inconsistent and expensive.\n\n"
"\nInstructions:\n\nWhy do you think the colonists needed help from Native Americans?.\n\nSolution:\n\nSettlers might have needed to get help from native tribes because they were facing starvation and disease. Coming from settled Europe, they lacked key survival skills, and resupply over the ocean was inconsistent and expensive.\n\n"
]
]
}
app/llm/generators/concepts.py (1 addition, 1 deletion)
@@ -22,7 +22,7 @@ class CourseGeneratedConcepts(BaseModel):
concept_settings = GenerationSettings(
temperature=0.7,
max_tokens=256,
- timeout=20,
+ timeout=40,
stop_tokens=None,
prompt_type="concept",
model=settings.LLM_INSTRUCT_TYPE,
app/llm/generators/title.py (4 additions, 1 deletion)
@@ -11,7 +11,10 @@
from app.util import extract_only_json_list

title_settings = GenerationSettings(
- temperature=0.9, max_tokens=512, timeout=20, prompt_type="title"
+ temperature=0.9,
+ max_tokens=512,
+ timeout=40,
+ prompt_type="title"
)


app/llm/generators/topic.py (1 addition, 1 deletion)
@@ -13,7 +13,7 @@
topic_settings = GenerationSettings(
temperature=0.9,
max_tokens=512,
- timeout=20,
+ timeout=40,
stop_tokens=None,
prompt_type="topic",
model=settings.LLM_INSTRUCT_TYPE,
app/settings.py (3 additions, 0 deletions)
@@ -50,6 +50,9 @@ class Settings(BaseSettings):
SERPAPI_KEY: str = ""
SEARCH_BACKEND: Optional[str] = "serply"

+ # General
+ THREADS_PER_WORKER: int = 4 # How many threads to use per worker process to save RAM

class Config:
env_file = find_dotenv("local.env")

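A note on tuning: assuming pydantic v1-style BaseSettings (consistent with the Config/env_file pattern this class already uses), the new field can be overridden from the environment or local.env instead of editing code. A minimal sketch:

```python
# Minimal sketch, not repo code: BaseSettings reads matching environment
# variables, so an exported THREADS_PER_WORKER overrides the default of 4.
import os

from pydantic import BaseSettings  # pydantic v1; v2 moved this to pydantic-settings

class Settings(BaseSettings):
    THREADS_PER_WORKER: int = 4

os.environ["THREADS_PER_WORKER"] = "8"  # e.g. set in local.env or the shell
print(Settings().THREADS_PER_WORKER)  # 8
```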
book_generator.py (19 additions, 5 deletions)
@@ -1,4 +1,5 @@
import asyncio
+ import math
import traceback
from dataclasses import dataclass
from typing import List, Dict, Optional
@@ -73,14 +74,18 @@ async def generate_single_course(course_name, outline_items=12):
return course


- def process_course(course):
+ async def _process_courses(courses):
try:
- return asyncio.run(generate_single_course(course))
+ return await asyncio.gather(*[generate_single_course(course) for course in courses], return_exceptions=True)
except Exception as e:
debug_print_trace()
print(f"Unhandled error generating course: {e}")


+ def process_courses(courses):
+ return asyncio.run(_process_courses(courses))
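For reference, a minimal standalone sketch of the error-handling pattern used here: asyncio.gather with return_exceptions=True hands back failed coroutines as exception objects in the results list instead of raising, which is why the output loop below checks isinstance(course, Exception).

```python
import asyncio

async def may_fail(n: int) -> int:
    # Hypothetical task standing in for generate_single_course.
    if n == 2:
        raise ValueError("boom")
    return n * 10

async def main():
    results = await asyncio.gather(
        *[may_fail(n) for n in range(4)], return_exceptions=True
    )
    print(results)  # [0, 10, ValueError('boom'), 30]

asyncio.run(main())
```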


def load_topics(in_file: str, max_topics: Optional[str]):
with open(os.path.join(settings.DATA_DIR, in_file)) as f:
topics = json.load(f)
@@ -104,20 +109,29 @@ def load_topics(in_file: str, max_topics: Optional[str]):

topics = load_topics(args.in_file, max_topics=args.max)

- courses = process_map(process_course, topics, max_workers=args.workers, chunksize=1)
+ total_processes = math.ceil(args.workers / settings.THREADS_PER_WORKER)

+ # group topics into batches of settings.THREADS_PER_WORKER
+ batched_topics = [topics[i:i + settings.THREADS_PER_WORKER] for i in range(0, len(topics), settings.THREADS_PER_WORKER)]

+ courses = process_map(process_courses, batched_topics, max_workers=total_processes, chunksize=1)

+ # Flatten courses list
+ courses = [course for batch in courses for course in batch]

with open(os.path.join(settings.DATA_DIR, args.out_file), "w+") as f:
for course, topic in zip(courses, topics):

# Filter out courses that didn't generate properly
- if course is None or course.markdown is None or len(course.markdown) == 0:
+ if course is None or isinstance(course, Exception) or course.markdown is None or len(course.markdown) == 0:
continue
json_data = {
"topic": topic,
"model": settings.LLM_TYPE,
"concepts": course.concepts,
"outline": course.outline,
"markdown": course.markdown
"markdown": course.markdown,
"components": course.components
}
f.write(json.dumps(json_data) + '\n')
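Per the commit title and the THREADS_PER_WORKER comment in settings.py, the memory saving comes from this regrouping: each worker process now runs a batch of courses concurrently on one event loop, so fewer processes are spawned overall. A minimal sketch of the arithmetic, with hypothetical numbers:

```python
import math

# Hypothetical values: 10 requested workers, 4 courses per process.
workers = 10
threads_per_worker = 4
topics = [f"topic-{i}" for i in range(10)]

total_processes = math.ceil(workers / threads_per_worker)
batched = [
    topics[i:i + threads_per_worker]
    for i in range(0, len(topics), threads_per_worker)
]
print(total_processes)            # 3
print([len(b) for b in batched])  # [4, 4, 2]
```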

scripts/clear_model_data.py (40 additions, 0 deletions)
@@ -0,0 +1,40 @@
import asyncio
import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from sqlalchemy import delete
from sqlmodel import select

from app.course.models import Course
from app.db.session import get_session
from app.llm.models import Prompt

import argparse


async def clear_model_data(model: str, dry_run=False):
async with get_session() as db:
query = await db.exec(select(Prompt).where(Prompt.model == model))
prompts = query.all()
print(f"Found {len(prompts)} prompts for model {model}.")
if not dry_run:
await db.exec(delete(Prompt).where(Prompt.model == model))

query = await db.exec(select(Course).where(Course.model == model))
courses = query.all()
print(f"Found {len(courses)} courses for model {model}.")
if not dry_run:
await db.exec(delete(Course).where(Course.model == model))
await db.commit()


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Remove all data related to a specific model.")
parser.add_argument("model", help="model name")
parser.add_argument("--dry-run", action="store_true", help="Don't actually delete anything.")
args = parser.parse_args()

asyncio.run(clear_model_data(args.model, args.dry_run))
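For illustration, a minimal synchronous sketch of the script's select-then-delete pattern; an in-memory SQLite database stands in for the app database and the model names are hypothetical, since the repo's async get_session() needs a configured backend.

```python
from typing import Optional

from sqlalchemy import delete
from sqlmodel import Field, Session, SQLModel, create_engine, select

class Prompt(SQLModel, table=True):  # stand-in for app.llm.models.Prompt
    id: Optional[int] = Field(default=None, primary_key=True)
    model: str

engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)

with Session(engine) as db:
    db.add(Prompt(model="llama2"))  # hypothetical model names
    db.add(Prompt(model="gpt-4"))
    db.commit()

    prompts = db.exec(select(Prompt).where(Prompt.model == "llama2")).all()
    print(f"Found {len(prompts)} prompts for model llama2.")

    db.exec(delete(Prompt).where(Prompt.model == "llama2"))  # bulk DELETE
    db.commit()
```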


