Skip to content

Commit

Permalink
feat: add upsert mode to sqlserver.py and corresponding tests (#2835)
Browse files Browse the repository at this point in the history
Co-authored-by: Anton Mantulo <[email protected]>
  • Loading branch information
AntonMantulo and Anton Mantulo committed May 24, 2024
1 parent a7396e8 commit f5e8182
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 3 deletions.
29 changes: 26 additions & 3 deletions awswrangler/sqlserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,11 +436,12 @@ def to_sql(
con: "pyodbc.Connection",
table: str,
schema: str,
mode: Literal["append", "overwrite"] = "append",
mode: Literal["append", "overwrite", "upsert"] = "append",
index: bool = False,
dtype: dict[str, str] | None = None,
varchar_lengths: dict[str, int] | None = None,
use_column_names: bool = False,
upsert_conflict_columns: list[str] | None = None,
chunksize: int = 200,
fast_executemany: bool = False,
) -> None:
Expand All @@ -457,7 +458,12 @@ def to_sql(
schema : str
Schema name
mode : str
Append or overwrite.
Append, overwrite or upsert.
- append: Inserts new records into table.
- overwrite: Drops table and recreates.
- upsert: Perform an upsert which checks for conflicts on columns given by ``upsert_conflict_columns`` and sets the new values on conflicts. Note that column names of the DataFrame will be used for this operation, as if ``use_column_names`` was set to True.
index : bool
True to store the DataFrame index as a column in the table,
otherwise False to ignore it.
Expand All @@ -471,6 +477,8 @@ def to_sql(
If set to True, will use the column names of the DataFrame for generating the INSERT SQL Query.
E.g. If the DataFrame has two columns `col1` and `col3` and `use_column_names` is True, data will only be
inserted into the database columns `col1` and `col3`.
upsert_conflict_columns: List[str], optional
List of columns to be used as conflict columns in the upsert operation.
chunksize: int
Number of rows which are inserted with each SQL query. Defaults to inserting 200 rows per query.
fast_executemany: bool
Expand Down Expand Up @@ -506,6 +514,8 @@ def to_sql(
if df.empty is True:
raise exceptions.EmptyDataFrame("DataFrame cannot be empty.")
_validate_connection(con=con)
if mode == "upsert" and not upsert_conflict_columns:
raise exceptions.InvalidArgumentValue("<upsert_conflict_columns> need to be set when using upsert mode.")
try:
with con.cursor() as cursor:
if fast_executemany:
Expand All @@ -524,15 +534,28 @@ def to_sql(
df.reset_index(level=df.index.names, inplace=True)
column_placeholders: str = ", ".join(["?"] * len(df.columns))
table_identifier = _get_table_identifier(schema, table)
column_names = [identifier(col, sql_mode="mssql") for col in df.columns]
quoted_columns = ", ".join(column_names)
insertion_columns = ""
if use_column_names:
quoted_columns = ", ".join(f"{identifier(col, sql_mode='mssql')}" for col in df.columns)
insertion_columns = f"({quoted_columns})"
placeholder_parameter_pair_generator = _db_utils.generate_placeholder_parameter_pairs(
df=df, column_placeholders=column_placeholders, chunksize=chunksize
)
for placeholders, parameters in placeholder_parameter_pair_generator:
sql: str = f"INSERT INTO {table_identifier} {insertion_columns} VALUES {placeholders}"
if mode == "upsert" and upsert_conflict_columns:
merge_on_columns = [identifier(col, sql_mode="mssql") for col in upsert_conflict_columns]
sql = f"MERGE INTO {table_identifier}\nUSING (VALUES {placeholders}) AS source ({quoted_columns})\n"
sql += f"ON {' AND '.join(f'{table_identifier}.{col}=source.{col}' for col in merge_on_columns)}\n"
sql += (
f"WHEN MATCHED THEN\n UPDATE "
f"SET {', '.join(f'{col}=source.{col}' for col in column_names)}\n"
)
sql += (
f"WHEN NOT MATCHED THEN\n INSERT "
f"({quoted_columns}) VALUES ({', '.join([f'source.{col}' for col in column_names])});"
)
_logger.debug("sql: %s", sql)
cursor.executemany(sql, (parameters,))
con.commit()
Expand Down
71 changes: 71 additions & 0 deletions tests/unit/test_sqlserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,74 @@ def test_dfs_are_equal_for_different_chunksizes(sqlserver_table, sqlserver_con,
df["c1"] = df["c1"].astype("string")

assert df.equals(df2)


def test_upsert(sqlserver_table, sqlserver_con):
    """Exercise ``mode="upsert"`` of ``wr.sqlserver.to_sql`` against a real table.

    Steps checked, in order:

    1. Upserting with ``upsert_conflict_columns=None`` raises
       ``InvalidArgumentValue``.
    2. Upserting the same two-row frame repeatedly on ``c0`` leaves two rows.
    3. Upserting a frame with one new key (``baz``) and one existing key
       (``bar``) yields three rows.
    4. Upserting new ``c2`` values for the existing keys ``foo`` and ``bar``
       keeps three rows and stores the new values.
    """

    def _upsert(frame, **extra):
        # Every write in this test targets the same table in upsert mode;
        # only the frame and the optional keyword arguments vary.
        wr.sqlserver.to_sql(
            df=frame,
            con=sqlserver_con,
            schema="dbo",
            table=sqlserver_table,
            mode="upsert",
            **extra,
        )

    def _read_back():
        return wr.sqlserver.read_sql_table(con=sqlserver_con, schema="dbo", table=sqlserver_table)

    initial = pd.DataFrame({"c0": ["foo", "bar"], "c2": [1, 2]})

    # Missing conflict columns must be rejected.
    with pytest.raises(wr.exceptions.InvalidArgumentValue):
        _upsert(initial, upsert_conflict_columns=None, use_column_names=True)

    # Writing identical data twice must not duplicate rows.
    for _ in range(2):
        _upsert(initial, upsert_conflict_columns=["c0"])
    after_repeat = _read_back()
    assert len(after_repeat) == 2

    # One more identical write, then a frame mixing a new and an existing key.
    _upsert(initial, upsert_conflict_columns=["c0"])
    with_new_key = pd.DataFrame({"c0": ["baz", "bar"], "c2": [3, 2]})
    _upsert(with_new_key, upsert_conflict_columns=["c0"], use_column_names=True)
    after_new_key = _read_back()
    assert len(after_new_key) == 3

    # Existing keys with changed values: row count is stable, values are updated.
    changed_values = pd.DataFrame({"c0": ["foo", "bar"], "c2": [4, 5]})
    _upsert(changed_values, upsert_conflict_columns=["c0"], use_column_names=True)

    final = _read_back()
    assert len(final) == 3
    for key, value in (("foo", 4), ("bar", 5)):
        assert len(final.loc[(final["c0"] == key) & (final["c2"] == value)]) == 1

0 comments on commit f5e8182

Please sign in to comment.