Skip to content

Commit

Permalink
Return the generated query as part of the JSON response object
Browse files Browse the repository at this point in the history
  • Loading branch information
lfunderburk committed May 4, 2023
1 parent c88db26 commit 43d9a49
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 31 deletions.
42 changes: 11 additions & 31 deletions src/app/app.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,14 @@
from fastapi import FastAPI
from pydantic import BaseModel
import os
from dotenv import load_dotenv
import pandas as pd
from pathlib import Path
from .app_utils import Prompter, init_database, data_cleaner

def init_data():
# Set the path to the raw data
# Convert the current working directory to a Path object
script_dir = Path(os.getcwd())
predicted_data_path = script_dir / 'data' / 'predicted-data' / 'vehicle_data_with_clusters.csv'

# Load the CSV file into a DataFrame
dirty_df = pd.read_csv(predicted_data_path)
global df
df = data_cleaner(dirty_df)
global sample_values
sample_values = {df.columns[i]: df.values[0][i] for i in range(len(df.columns))}
from .app_utils import Prompter, init_database, init_prompt, init_data
from pydantic import BaseModel

return df, sample_values
class Query(BaseModel):
query: str

app = FastAPI()

class Query(BaseModel):
query: str

@app.get("/")
async def root():
Expand All @@ -42,6 +26,8 @@ async def startup_event():
global prompter
prompter = Prompter(openai_api_key, "gpt-4")

# Initialize data
global df, sample_values
df, sample_values = init_data()

# Set up engine
Expand All @@ -50,16 +36,10 @@ async def startup_event():

@app.post("/search")
async def search(query: Query):
# Generate SQL query
datagen_prompts = [
{"role" : "system", "content" : "You are a data analyst specializing in SQL, you are presented with a natural language query, and you form queries to answer questions about the data."},
{"role" : "user", "content" : f"Please generate 1 SQL queries for data with columns {', '.join(df.columns)} and sample values {sample_values}. \
The table is called 'vehicleDB'. Use the natural language query {query.query}"},
]

result = prompter.prompt_model_return(datagen_prompts)
print(result)
sql_query = result.split("\n\n")[0]

# Initialize prompt
sql_query = init_prompt(query, prompter, df, sample_values)
print(sql_query)

try:
# Execute SQL query and fetch results
Expand All @@ -70,8 +50,8 @@ async def search(query: Query):
# Convert rows to list of dicts for JSON response
columns = result.keys()
data = [dict(zip(columns, row)) for row in rows]
return {"data": data, "sql_query": sql_query, "status": "success"}

except Exception as e:
return {"error": f"SQL query failed. {e}"}

return {"data": data}
44 changes: 44 additions & 0 deletions src/app/app_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
import openai
from sqlalchemy.engine import create_engine
from pathlib import Path
from pydantic import BaseModel

class Query(BaseModel):
    """Request body for the /search endpoint.

    Carries the user's natural-language question, which is later translated
    into a SQL statement by init_prompt.
    """
    # Natural-language question to be turned into a SQL query.
    query: str

class Prompter:
def __init__(self, api_key, gpt_model):
Expand All @@ -20,6 +24,46 @@ def prompt_model_return(self, messages: list):
temperature=0.2)
return response["choices"][0]["message"]["content"]



def init_prompt(query: Query, prompter: Prompter, df: pd.DataFrame, sample_values: dict) -> str:
    """Translate a natural-language query into a single SQL statement.

    Builds a system/user chat-message pair describing the 'vehicleDB' table
    (column names taken from *df*, example values from *sample_values*),
    sends it through *prompter*, and returns only the first paragraph of the
    model's reply as the SQL query.

    Args:
        query: Request object whose ``query`` field holds the user's question.
        prompter: Chat-completion wrapper; its ``prompt_model_return`` is
            called with the message list and returns the model's reply text.
        df: Cleaned vehicle DataFrame; only its column names are used here.
        sample_values: Mapping of column name -> sample value from one row,
            interpolated verbatim into the user prompt.

    Returns:
        The generated SQL query string (text before the first blank line of
        the model's reply).
    """
    # Generate SQL query
    # NOTE(review): the backslash continuations below are *inside* the string
    # literals, so the prompt text sent to the model embeds the continuation
    # lines' leading whitespace — confirm that is acceptable.
    system_content = "You are a data analyst specializing in SQL, \
                    you are presented with a natural language query, \
                    and you form queries to answer questions about the data."
    user_content = f"Please generate 1 SQL queries for data with \
                    columns {', '.join(df.columns)} \
                    and sample values {sample_values}. \
                    The table is called 'vehicleDB'. \
                    Use the natural language query {query.query}"
    datagen_prompts = [
        {"role" : "system", "content" : system_content},
        {"role" : "user", "content" : user_content},
    ]

    # Take parameters and form a SQL query
    sql_result= prompter.prompt_model_return(datagen_prompts)

    # Sometimes the query is verbose - adding unnecessary explanations:
    # keep only the text before the first blank line of the reply.
    sql_query = sql_result.split("\n\n")[0]

    return sql_query

def init_data():
    """Load the clustered vehicle dataset and derive per-column sample values.

    Reads ``data/predicted-data/vehicle_data_with_clusters.csv`` relative to
    the current working directory, cleans it with :func:`data_cleaner`, and
    returns the cleaned frame plus a mapping of each column name to its value
    in the first row (used as example values when prompting the LLM).

    Returns:
        tuple: ``(df, sample_values)`` where ``df`` is the cleaned
        ``pd.DataFrame`` and ``sample_values`` is a ``dict`` of
        column name -> first-row value.

    Raises:
        FileNotFoundError: if the CSV is not present under the CWD.
    """
    # Path.cwd() is the idiomatic pathlib equivalent of Path(os.getcwd()).
    predicted_data_path = (
        Path.cwd() / 'data' / 'predicted-data' / 'vehicle_data_with_clusters.csv'
    )

    # Load the CSV file into a DataFrame
    dirty_df = pd.read_csv(predicted_data_path)

    # NOTE(review): the module-level globals mirror the return values; kept
    # for backward compatibility with any module-level readers — confirm
    # nothing depends on them before removing.
    global df
    df = data_cleaner(dirty_df)
    global sample_values
    # Map each column to its value in the first row.
    sample_values = dict(zip(df.columns, df.values[0]))

    return df, sample_values

def data_cleaner(df):


Expand Down

0 comments on commit 43d9a49

Please sign in to comment.