Skip to content

Commit

Permalink
Merge pull request #248 from vanna-ai/db-connection-updates
Browse files Browse the repository at this point in the history
Add MS SQL connection
  • Loading branch information
zainhoda committed Feb 21, 2024
2 parents ef350bc + 6b0393f commit b3d46dc
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 3 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "vanna"
version = "0.1.0"
version = "0.1.1"
authors = [
{ name="Zain Hoda", email="[email protected]" },
]
Expand All @@ -18,7 +18,7 @@ classifiers = [
"Operating System :: OS Independent",
]
dependencies = [
"requests", "tabulate", "plotly", "pandas", "sqlparse", "kaleido", "flask"
"requests", "tabulate", "plotly", "pandas", "sqlparse", "kaleido", "flask", "sqlalchemy"
]

[project.urls]
Expand Down
51 changes: 50 additions & 1 deletion src/vanna/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,10 @@ def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]:
except psycopg2.Error as e:
conn.rollback()
raise ValidationError(e)

except Exception as e:
conn.rollback()
raise e

self.run_sql_is_set = True
self.run_sql = run_sql_postgres
Expand Down Expand Up @@ -829,7 +833,7 @@ def connect_to_duckdb(self, url: str, init_sql: str = None):
Connect to a DuckDB database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
Args:
url (str): The URL of the database to connect to.
url (str): The URL of the database to connect to. Use :memory: to create an in-memory database. Use md: or motherduck: to use the MotherDuck database.
init_sql (str, optional): SQL to run when connecting to the database. Defaults to None.
Returns:
Expand All @@ -850,6 +854,8 @@ def connect_to_duckdb(self, url: str, init_sql: str = None):
print(os.path.exists(url))
if os.path.exists(url):
path=url
elif url.startswith("md") or url.startswith("motherduck"):
path = url
else:
path = os.path.basename(urlparse(url).path)
# Download the database if it doesn't exist
Expand All @@ -870,6 +876,49 @@ def run_sql_duckdb(sql: str):
self.run_sql = run_sql_duckdb
self.run_sql_is_set = True

def connect_to_mssql(self, odbc_conn_str: str):
"""
Connect to a Microsoft SQL Server database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
Args:
odbc_conn_str (str): The ODBC connection string.
Returns:
None
"""
try:
import pyodbc
except ImportError:
raise DependencyError(
"You need to install required dependencies to execute this method,"
" run command: pip install pyodbc"
)

try:
from sqlalchemy.engine import URL
import sqlalchemy as sa
except ImportError:
raise DependencyError(
"You need to install required dependencies to execute this method,"
" run command: pip install sqlalchemy"
)

connection_url = URL.create("mssql+pyodbc", query={"odbc_connect": odbc_conn_str})

from sqlalchemy import create_engine
engine = create_engine(connection_url)

def run_sql_mssql(sql: str):
# Execute the SQL statement and return the result as a pandas DataFrame
with engine.begin() as conn:
df = pd.read_sql_query(sa.text(sql), conn)
return df

raise Exception("Couldn't run sql")

self.run_sql = run_sql_mssql
self.run_sql_is_set = True

def run_sql(self, sql: str, **kwargs) -> pd.DataFrame:
"""
Example:
Expand Down

0 comments on commit b3d46dc

Please sign in to comment.