Skip to content

Commit

Permalink
[BUGFIX] Snowflake - Fix private_key Unicode serialization errors (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Kilo59 committed Jun 18, 2024
1 parent b802cbe commit 76cc366
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ repos:
- id: no-commit-to-branch
args: [--branch, develop, --branch, main]
- id: detect-private-key
exclude: tests/test_fixtures/database_key_test*
exclude: tests/test_fixtures/database_key_test*|tests/datasource/fluent/test_snowflake_datasource\.py
- repo: https://github.com/psf/black
rev: 23.10.1
hooks:
Expand Down
60 changes: 52 additions & 8 deletions great_expectations/datasource/fluent/snowflake_datasource.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import base64
import functools
import logging
import urllib.parse
Expand Down Expand Up @@ -45,6 +46,27 @@
MISSING: Final = object() # sentinel value to indicate missing values


@functools.lru_cache(maxsize=4)
def _is_b64_encoded(sb: str | bytes) -> bool:
"""
Check if a string or bytes is base64 encoded.
By decoding and then encoding it, we can check if it's the same.
Copied from: https://stackoverflow.com/a/45928164/6304433
"""
try:
if isinstance(sb, str):
# If there's any unicode here, an exception will be thrown and the function will return false
sb_bytes = bytes(sb, "ascii")
elif isinstance(sb, bytes):
sb_bytes = sb
else:
raise TypeError("Argument must be string or bytes")
return base64.b64encode(base64.b64decode(sb_bytes)) == sb_bytes
except Exception:
return False


@functools.lru_cache(maxsize=4)
def _get_database_and_schema_from_path(
url_path: str,
Expand Down Expand Up @@ -351,6 +373,18 @@ def _asset_forward_compatibility(cls, assets: list[dict]) -> list[dict]:
LOGGER.error(f"Error attempting forward compatibility modifications: {e!r}")
return assets

@pydantic.validator("kwargs")
def _base64_encode_private_key(cls, kwargs: dict) -> dict:
if connect_args := kwargs.get("connect_args", {}):
if private_key := connect_args.get("private_key"):
# test if it's already base64 encoded
if _is_b64_encoded(private_key):
LOGGER.info("private_key is already base64 encoded")
else:
LOGGER.info("private_key is not base64 encoded, encoding now")
connect_args["private_key"] = base64.standard_b64encode(private_key)
return kwargs

class Config:
@staticmethod
def schema_extra(schema: dict, model: type[SnowflakeDatasource]) -> None:
Expand Down Expand Up @@ -437,9 +471,8 @@ def get_engine(self) -> sqlalchemy.Engine:
"application": self._get_snowflake_partner_application()
}
)
self._engine = sa.create_engine(
url,
**kwargs,
self._engine = self._build_engine_with_connect_args(
url=url, **kwargs
)
else:
self._engine = self._build_engine_with_connect_args(
Expand All @@ -461,21 +494,32 @@ def get_engine(self) -> sqlalchemy.Engine:
return self._engine

def _build_engine_with_connect_args(
self, connect_args: dict[str, Any] | None = None, **kwargs
self,
url: URL | None = None,
connect_args: dict[str, Any] | None = None,
**kwargs,
) -> sqlalchemy.Engine:
url_args = self._get_url_args()
url_args.update(kwargs)
if not url:
url_args = self._get_url_args()
url_args.update(kwargs)
url = URL(**url_args)
else:
url_args = {}

engine_kwargs: dict[Literal["url", "connect_args"], Any] = {}
if connect_args:
if connect_args.get("private_key"):
if private_key := connect_args.get("private_key"):
url_args.pop( # TODO: update models + validation to handle this
"password", None
)
LOGGER.info(
"private_key detected, ignoring password and using private_key for authentication"
)
# assume the private_key is base64 encoded
connect_args["private_key"] = base64.standard_b64decode(private_key)

engine_kwargs["connect_args"] = connect_args
engine_kwargs["url"] = URL(**url_args)

engine_kwargs["url"] = url

return sa.create_engine(**engine_kwargs)
6 changes: 6 additions & 0 deletions tests/datasource/fluent/great_expectations.yml
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,12 @@ fluent_datasources:
id: e9ac5d80-679b-49d2-8c60-3bbf6530dbb4
type: table
table_name: my_table
my_snow_flake_w_private_key:
type: snowflake
connection_string: "snowflake:https://user_login_name:dummy_value@account_identifier"
kwargs:
connect_args:
private_key: LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlDWEFJQkFBS0JnSHBGZENkS09HTGFpTUg5dDF0aDFsS3FKVmNEd2ZubFAybHBuZUFOYmJzZ0hiNi80VTJVCnVhMDg1emxOWWhaNXhKc25TZHFJQWZyYWd6dVZZTmsyT0Nwb04xUWtxNG9XYWQwYTRjRUIyUUJ0UDlqczBkVlcKeFFPYkpNOHQxWkxIQjNMdzFOQ3FCNk9lZmtQN1hsRTB3NmFYUlo1SVd3dlZDODZjQlhWQm1CWHpBZ01CQUFFQwpnWUFCUjZUVm5ITkdwWjcwMk9FSWRkZTJlYzEyUWJYUUZkUTZHRDdzejNjc2xFTjdjYXE4RXloMlpjTE4yTCtFCkdMWTBJWThtV0hJYzNCaXZrUHE0aTFhL0p5UlV6RVRvSnZqVmQ4SjFzbHJ6ejhyeU1PQWlQYnh0MzNJcGdHTDMKLzhLZ09MWXhqZGc1YnBuNnNDWmxPWHk3V1lqbDFIOFRCdzhDelpGNDFIYTI0UUpCQU03VSs4bTBoeWtuYm5CRApnS1hHYjBlSElCeDB6bFBhTkp3REhVY0pYdWp4YlZmd1ZqS1dMeTA3Sm9YUmlBZ1B1VnN6SU1odTByK1hhODdMClcyV0xkVHNDUVFDWFZtMEhlN1NheXRucmxBRmNrNS9MNEVqdFdhQVFHZm1WNGVhd0kySGVtV01qajB0dWtkRnQKd0FXSER1S1lNYitiZzIxT1UyWFF4b2xsWVlKZmsvYXBBa0JhU2UxMFd1Tloyc1hDS2lXQnVJTWhaV0ptS2JOYwpOWGdiMXR3MEEybzBKQmhJZURrWXNpajhCTU5IVFhXbGx6K2lDVXE1VkcrWmhYOWhjYkovUElhN0FrRUFqZmdkCnYrOWt0ZkdtRFVHREpYMjNZbUs5Qnl3VTVBWDZCWWt1Qi82cFNWRkxsNGhOa3lSbit6VXYra3NVZHdIMFpjY2QKTzJVeEZuR3BZdG5lbkJzS1FRSkJBTVgydGdGY2cvL3QxTGk0K2R4bFR2WlovY2xaQ0xwV1hwNEhRZ0J3enhNTgp3cERvRjQwT3pOWXJyS0lib1U0QkpGT01PQldBUzRERkRZR2RmTFZTOTlnPQotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQ==
my_databricks_sql_ds:
type: databricks_sql
connection_string: "databricks:https://token:[email protected]:123/default_db?http_path=/sql/1.0/warehouses/abc123&catalog=default&schema=dev"
Expand Down
64 changes: 59 additions & 5 deletions tests/datasource/fluent/test_snowflake_datasource.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import base64
import logging
from pprint import pformat as pf
from sys import version_info as python_version
Expand Down Expand Up @@ -27,6 +28,29 @@

TEST_LOGGER: Final = logging.getLogger(__name__)

_EXAMPLE_PRIVATE_KEY: Final[
bytes
] = b"""-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgHpFdCdKOGLaiMH9t1th1lKqJVcDwfnlP2lpneANbbsgHb6/4U2U
ua085zlNYhZ5xJsnSdqIAfragzuVYNk2OCpoN1Qkq4oWad0a4cEB2QBtP9js0dVW
xQObJM8t1ZLHB3Lw1NCqB6OefkP7XlE0w6aXRZ5IWwvVC86cBXVBmBXzAgMBAAEC
gYABR6TVnHNGpZ702OEIdde2ec12QbXQFdQ6GD7sz3cslEN7caq8Eyh2ZcLN2L+E
GLY0IY8mWHIc3BivkPq4i1a/JyRUzEToJvjVd8J1slrzz8ryMOAiPbxt33IpgGL3
/8KgOLYxjdg5bpn6sCZlOXy7WYjl1H8TBw8CzZF41Ha24QJBAM7U+8m0hyknbnBD
gKXGb0eHIBx0zlPaNJwDHUcJXujxbVfwVjKWLy07JoXRiAgPuVszIMhu0r+Xa87L
W2WLdTsCQQCXVm0He7SaytnrlAFck5/L4EjtWaAQGfmV4eawI2HemWMjj0tukdFt
wAWHDuKYMb+bg21OU2XQxollYYJfk/apAkBaSe10WuNZ2sXCKiWBuIMhZWJmKbNc
NXgb1tw0A2o0JBhIeDkYsij8BMNHTXWllz+iCUq5VG+ZhX9hcbJ/PIa7AkEAjfgd
v+9ktfGmDUGDJX23YmK9BywU5AX6BYkuB/6pSVFLl4hNkyRn+zUv+ksUdwH0Zccd
O2UxFnGpYtnenBsKQQJBAMX2tgFcg//t1Li4+dxlTvZZ/clZCLpWXp4HQgBwzxMN
wpDoF40OzNYrrKIboU4BJFOMOBWAS4DFDYGdfLVS99g=
-----END RSA PRIVATE KEY-----"""

_EXAMPLE_B64_ENCODED_PRIVATE_KEY: Final[bytes] = base64.standard_b64encode(
_EXAMPLE_PRIVATE_KEY
)


VALID_DS_CONFIG_PARAMS: Final[Sequence[ParameterSet]] = [
param(
{
Expand Down Expand Up @@ -137,6 +161,28 @@
},
id="min connection_string dict with password ConfigStr",
),
param(
{
"connection_string": {
"user": "my_user",
"password": "DUMMY_VALUE",
"account": "my_account",
},
"kwargs": {"connect_args": {"private": _EXAMPLE_PRIVATE_KEY}},
},
id="private_key auth",
),
param(
{
"connection_string": {
"user": "my_user",
"password": "DUMMY_VALUE",
"account": "my_account",
},
"kwargs": {"connect_args": {"private": _EXAMPLE_B64_ENCODED_PRIVATE_KEY}},
},
id="private_key auth b64 encoded",
),
]


Expand Down Expand Up @@ -678,7 +724,7 @@ def test_get_engine_correctly_sets_application_query_param(
"name": "std connection_str",
"connection_string": "snowflake:https://user:password@account/db/schema?warehouse=wh&role=role",
},
{},
{"url": ANY},
id="std connection_string str",
),
param(
Expand All @@ -694,7 +740,9 @@ def test_get_engine_correctly_sets_application_query_param(
"role": "role",
},
},
{},
{
"url": "snowflake:https://user:password@account/db/schema?application=great_expectations_core&role=role&warehouse=wh",
},
id="std connection_string dict",
),
param(
Expand All @@ -703,7 +751,10 @@ def test_get_engine_correctly_sets_application_query_param(
"connection_string": "snowflake:https://user:password@account/db/schema?warehouse=wh&role=role",
"kwargs": {"connect_args": {"private_key": b"my_key"}},
},
{"connect_args": {"private_key": b"my_key"}},
{
"connect_args": {"private_key": b"my_key"},
"url": ANY,
},
id="connection_string str with connect_args",
),
param(
Expand All @@ -720,7 +771,10 @@ def test_get_engine_correctly_sets_application_query_param(
},
"kwargs": {"connect_args": {"private_key": b"my_key"}},
},
{"connect_args": {"private_key": b"my_key"}},
{
"connect_args": {"private_key": b"my_key"},
"url": "snowflake:https://user:password@account/db/schema?application=great_expectations_core&role=role&warehouse=wh",
},
id="connection_string dict with connect_args",
),
],
Expand All @@ -739,7 +793,7 @@ def test_create_engine_is_called_with_expected_kwargs(
engine = datasource.get_engine()
print(engine)

create_engine_spy.assert_called_once_with(ANY, **expected_called_with)
create_engine_spy.assert_called_once_with(**expected_called_with)


@pytest.mark.snowflake
Expand Down

0 comments on commit 76cc366

Please sign in to comment.