From 43d9a49444ea094bcc52b093c43a66fd7865e452 Mon Sep 17 00:00:00 2001 From: Laura Gutierrez Funderburk Date: Thu, 4 May 2023 10:13:20 -0700 Subject: [PATCH] return query as part of json object --- src/app/app.py | 42 +++++++++++------------------------------- src/app/app_utils.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 31 deletions(-) diff --git a/src/app/app.py b/src/app/app.py index 5205d78..dc5d99b 100644 --- a/src/app/app.py +++ b/src/app/app.py @@ -1,30 +1,14 @@ from fastapi import FastAPI -from pydantic import BaseModel import os from dotenv import load_dotenv -import pandas as pd -from pathlib import Path -from .app_utils import Prompter, init_database, data_cleaner - -def init_data(): - # Set the path to the raw data - # Convert the current working directory to a Path object - script_dir = Path(os.getcwd()) - predicted_data_path = script_dir / 'data' / 'predicted-data' / 'vehicle_data_with_clusters.csv' - - # Load the CSV file into a DataFrame - dirty_df = pd.read_csv(predicted_data_path) - global df - df = data_cleaner(dirty_df) - global sample_values - sample_values = {df.columns[i]: df.values[0][i] for i in range(len(df.columns))} +from .app_utils import Prompter, init_database, init_prompt, init_data +from pydantic import BaseModel - return df, sample_values +class Query(BaseModel): + query: str app = FastAPI() -class Query(BaseModel): - query: str @app.get("/") async def root(): @@ -42,6 +26,8 @@ async def startup_event(): global prompter prompter = Prompter(openai_api_key, "gpt-4") + # Initialize data + global df, sample_values df, sample_values = init_data() # Set up engine @@ -50,16 +36,10 @@ async def startup_event(): @app.post("/search") async def search(query: Query): - # Generate SQL query - datagen_prompts = [ - {"role" : "system", "content" : "You are a data analyst specializing in SQL, you are presented with a natural language query, and you form queries to answer questions about the data."}, - {"role" : "user", "content" : f"Please generate 1 SQL queries for data with columns {', '.join(df.columns)} and sample values {sample_values}. \ - The table is called 'vehicleDB'. Use the natural language query {query.query}"}, - ] - - result = prompter.prompt_model_return(datagen_prompts) - print(result) - sql_query = result.split("\n\n")[0] + + # Initialize prompt + sql_query = init_prompt(query, prompter, df, sample_values) + print(sql_query) try: # Execute SQL query and fetch results @@ -70,8 +50,8 @@ async def search(query: Query): # Convert rows to list of dicts for JSON response columns = result.keys() data = [dict(zip(columns, row)) for row in rows] + return {"data": data, "sql_query": sql_query, "status": "success"} except Exception as e: return {"error": f"SQL query failed. {e}"} - return {"data": data} diff --git a/src/app/app_utils.py b/src/app/app_utils.py index 0456cf6..2a33a2b 100644 --- a/src/app/app_utils.py +++ b/src/app/app_utils.py @@ -3,6 +3,10 @@ import openai from sqlalchemy.engine import create_engine from pathlib import Path +from pydantic import BaseModel + +class Query(BaseModel): + query: str class Prompter: def __init__(self, api_key, gpt_model): @@ -20,6 +24,46 @@ def prompt_model_return(self, messages: list): temperature=0.2) return response["choices"][0]["message"]["content"] + + +def init_prompt(query: Query, prompter: Prompter, df: pd.DataFrame, sample_values: dict): + # Generate SQL query + system_content = "You are a data analyst specializing in SQL, \ + you are presented with a natural language query, \ + and you form queries to answer questions about the data." + user_content = f"Please generate 1 SQL queries for data with \ + columns {', '.join(df.columns)} \ + and sample values {sample_values}. \ + The table is called 'vehicleDB'. \ + Use the natural language query {query.query}" + datagen_prompts = [ + {"role" : "system", "content" : system_content}, + {"role" : "user", "content" : user_content}, + ] + + # Take parameters and form a SQL query + sql_result= prompter.prompt_model_return(datagen_prompts) + + # Sometimes the query is verbose - adding unnecessary explanations + sql_query = sql_result.split("\n\n")[0] + + return sql_query + +def init_data(): + # Set the path to the raw data + # Convert the current working directory to a Path object + script_dir = Path(os.getcwd()) + predicted_data_path = script_dir / 'data' / 'predicted-data' / 'vehicle_data_with_clusters.csv' + + # Load the CSV file into a DataFrame + dirty_df = pd.read_csv(predicted_data_path) + global df + df = data_cleaner(dirty_df) + global sample_values + sample_values = {df.columns[i]: df.values[0][i] for i in range(len(df.columns))} + + return df, sample_values + def data_cleaner(df):