Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PySparkler: Integrate with sqlfluff and sqlfluff-plugin-sparksql-upgrade #79

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
PySparkler: Sqlfluff upgrades a formatted string SQL and provides cod…
…e hint for ones with complex expressions within.
  • Loading branch information
dhruv-pratap committed Jun 27, 2023
commit 395d70b7433fbda93e27013851287d085acd41cd
91 changes: 75 additions & 16 deletions pysparkler/pysparkler/sql_21_to_33.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,45 +42,104 @@ def __init__(
super().__init__(
transformer_id="PY21-33-001",
comment="Please note, PySparkler makes a best effort to upcast SQL statements directly being executed. \
However, the upgrade won't be possible for certain templated SQLs, and in those scenarios please de-template the SQL \
and use the Sqlfluff tooling to upcast the SQL yourself.",
However, the upcast won't be possible for certain formatted string SQL having complex expressions within, and in those \
cases please de-template the SQL and use the Sqlfluff tooling to upcast the SQL yourself.",
)
self.sql_upgraded = False

def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
    """Upcast the SQL argument of a ``*.sql(...)`` call to Spark 3.3 syntax.

    Matches any call whose attribute name is ``sql`` and whose single
    argument is a string literal (simple, formatted, or concatenated).
    Simple and formatted strings are routed to the matching update helper;
    anything else raises internally and falls into the advisory path.

    Never propagates an exception: on any failure the node is returned
    unchanged and ``self.comment`` is set to a code hint asking the user to
    de-template the SQL and run Sqlfluff themselves.
    """
    # NOTE(review): removed two leftover debug print statements that dumped
    # every visited Call node and every match to stdout.
    if m.matches(
        original_node,
        m.Call(
            func=m.Attribute(
                attr=m.Name("sql"),
            ),
            args=[
                m.Arg(
                    value=m.OneOf(
                        m.SimpleString(),
                        m.FormattedString(),
                        m.ConcatenatedString(),
                    )
                )
            ],
        ),
    ):
        self.match_found = True
        sql_node: cst.BaseExpression = updated_node.args[0].value
        try:
            if isinstance(sql_node, cst.SimpleString):
                updated_sql_node = self.update_simple_string_sql(sql_node)
            elif isinstance(sql_node, cst.FormattedString):
                updated_sql_node = self.update_formatted_string_sql(sql_node)
            else:
                # ConcatenatedString (and anything else) is not supported yet;
                # the except-branch below turns this into a code hint.
                raise NotImplementedError(
                    f"Unsupported SQL expression encountered : {sql_node}"
                )

            if self.sql_upgraded:
                self.comment = "Spark SQL statement has been upgraded to Spark 3.3 compatible syntax."
                self.sql_upgraded = False
            else:
                self.comment = (
                    "Spark SQL statement has Spark 3.3 compatible syntax."
                )

            return updated_node.with_changes(args=[cst.Arg(value=updated_sql_node)])
        except Exception as e:  # pylint: disable=broad-except
            # Deliberate best-effort: never fail the whole rewrite because one
            # SQL literal could not be parsed/upcast.
            print(f"Failed to parse SQL: {sql_node} with error: {e}")
            self.comment = "Unable to inspect the Spark SQL statement since the formatted string SQL has complex \
expressions within. Please de-template the SQL and use the Sqlfluff tooling to upcast the SQL yourself."
            self.sql_upgraded = False

    return updated_node

def update_simple_string_sql(self, sql_node: cst.SimpleString) -> cst.SimpleString:
    """Run sqlfluff over a plain string SQL literal.

    Returns the original node untouched when no fix is needed; otherwise
    flags ``self.sql_upgraded`` and returns a new ``SimpleString`` carrying
    the fixed SQL wrapped in the original prefix and quote characters.
    """
    original_sql = sql_node.evaluated_value
    fixed_sql = self.do_fix(original_sql)
    if fixed_sql == original_sql:
        return sql_node
    self.sql_upgraded = True
    rebuilt_literal = f"{sql_node.prefix}{sql_node.quote}{fixed_sql}{sql_node.quote}"
    return cst.SimpleString(value=rebuilt_literal)

def update_formatted_string_sql(
    self, sql_node: cst.FormattedString
) -> cst.BaseExpression:
    """Run sqlfluff over an f-string SQL literal with simple placeholders.

    Flattens the f-string back to raw SQL text, keeping ``{name}``
    placeholders verbatim so sqlfluff treats them as opaque tokens. Only
    bare-name expressions are supported; any other expression part raises
    ``NotImplementedError``, which the caller converts into a code hint.

    Returns the original node when no fix is needed, otherwise a freshly
    parsed expression carrying the fixed SQL (hence the ``BaseExpression``
    return type — ``cst.parse_expression`` does not guarantee a
    ``FormattedString``).
    """
    # Form the raw SQL string by concatenating all the parts
    sql = ""
    for part in sql_node.parts:
        if isinstance(part, cst.FormattedStringText):
            sql += part.value
        elif isinstance(part, cst.FormattedStringExpression) and isinstance(
            part.expression, cst.Name
        ):
            # Re-create the placeholder exactly, e.g. "{table_name}",
            # preserving any whitespace around the expression.
            sql += (
                part.whitespace_before_expression.value
                + "{"
                + part.expression.value
                + "}"
                + part.whitespace_after_expression.value
            )
        else:
            # Complex expressions (e.g. {num * 100}) cannot be round-tripped
            # safely through raw SQL text.
            raise NotImplementedError(
                f"Unsupported formatted string expression encountered : {part}"
            )

    updated_sql = self.do_fix(sql)
    if updated_sql != sql:
        self.sql_upgraded = True
        # Re-wrap the fixed SQL in the original prefix (e.g. "f") and quotes,
        # then re-parse it into a CST expression node.
        updated_sql_value = (
            sql_node.prefix + sql_node.quote + updated_sql + sql_node.quote
        )
        return cst.parse_expression(updated_sql_value)
    else:
        return sql_node

@staticmethod
def do_fix(sql: str) -> str:
return sqlfluff.fix(
Expand Down
65 changes: 64 additions & 1 deletion pysparkler/tests/test_sql_21_to_33.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,70 @@ def test_upgrades_non_templated_sql():
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("SQL Example").getOrCreate()
result = spark.sql("select int(dateint) val from my_table limit 10") # PY21-33-001: Please note, PySparkler makes a best effort to upcast SQL statements directly being executed. However, the upgrade won't be possible for certain templated SQLs, and in those scenarios please de-template the SQL and use the Sqlfluff tooling to upcast the SQL yourself. # noqa: E501
result = spark.sql("select int(dateint) val from my_table limit 10") # PY21-33-001: Spark SQL statement has been upgraded to Spark 3.3 compatible syntax. # noqa: E501
spark.stop()
"""
assert modified_code == expected_code


def test_upgrades_templated_sql():
    """An f-string SQL with only simple name placeholders is upcast in place."""
    input_code = """\
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("SQL Example").getOrCreate()
table_name = "my_table"
result = spark.sql(f"select cast(dateint as int) val from {table_name} limit 10")
spark.stop()
"""
    actual_code = rewrite(input_code, SqlStatementUpgradeAndCommentWriter())
    assert actual_code == """\
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("SQL Example").getOrCreate()
table_name = "my_table"
result = spark.sql(f"select int(dateint) val from {table_name} limit 10") # PY21-33-001: Spark SQL statement has been upgraded to Spark 3.3 compatible syntax. # noqa: E501
spark.stop()
"""


def test_unable_to_upgrade_templated_sql_with_complex_expressions():
    """An f-string SQL containing a non-trivial expression is left untouched,
    with a code hint comment appended instead of an upgrade."""
    input_code = """\
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("SQL Example").getOrCreate()
table_name = "my_table"
num = 10
result = spark.sql(f"select cast(dateint as int) val from {table_name} where x < {num * 100} limit 10")
spark.stop()
"""
    actual_code = rewrite(input_code, SqlStatementUpgradeAndCommentWriter())
    assert actual_code == """\
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("SQL Example").getOrCreate()
table_name = "my_table"
num = 10
result = spark.sql(f"select cast(dateint as int) val from {table_name} where x < {num * 100} limit 10") # PY21-33-001: Unable to inspect the Spark SQL statement since the formatted string SQL has complex expressions within. Please de-template the SQL and use the Sqlfluff tooling to upcast the SQL yourself. # noqa: E501
spark.stop()
"""


def test_no_upgrades_required_after_inspecting_sql():
    """A SQL statement that is already Spark 3.3 compatible only gets the
    informational comment — the statement text is unchanged."""
    input_code = """\
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("SQL Example").getOrCreate()
result = spark.sql("select * from my_table limit 10")
spark.stop()
"""
    actual_code = rewrite(input_code, SqlStatementUpgradeAndCommentWriter())
    assert actual_code == """\
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("SQL Example").getOrCreate()
result = spark.sql("select * from my_table limit 10") # PY21-33-001: Spark SQL statement has Spark 3.3 compatible syntax. # noqa: E501
spark.stop()
"""