
refactor utils, add tests, move exceptions into separate module #264

Open · wants to merge 10 commits into base: main
Changes from 8 commits
5 changes: 1 addition & 4 deletions simple_ddl_parser/ddl_parser.py
@@ -2,6 +2,7 @@

from ply.lex import LexToken

from simple_ddl_parser.exception import DDLParserError
from simple_ddl_parser import tokens as tok
from simple_ddl_parser.dialects import (
HQL,
@@ -19,10 +20,6 @@
from simple_ddl_parser.parser import Parser


class DDLParserError(Exception):
pass


class Dialects(
SparkSQL,
Snowflake,
8 changes: 8 additions & 0 deletions simple_ddl_parser/exception.py
@@ -0,0 +1,8 @@
__all__ = [
"DDLParserError",
]


class DDLParserError(Exception):
""" Base exception in simple ddl parser library """
pass
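Since this moves `DDLParserError` out of `ddl_parser.py`, downstream code now imports it from the new module. A minimal, hedged sketch of catching it (using the library's `DDLParser(...).run(...)` entry point with a deliberately invalid `output_mode`):

```python
from simple_ddl_parser import DDLParser
from simple_ddl_parser.exception import DDLParserError

ddl = "CREATE TABLE users (id INT);"

try:
    # After this PR an unknown output_mode raises DDLParserError
    # defined in simple_ddl_parser.exception.
    DDLParser(ddl).run(output_mode="no_such_dialect")
except DDLParserError as exc:
    print(f"parsing failed: {exc}")
```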
2 changes: 1 addition & 1 deletion simple_ddl_parser/output/core.py
@@ -123,7 +123,7 @@ def group_by_type_result(self) -> None:
else:
_type.extend(item["comments"])
break
if result_as_dict["comments"] == []:
if not result_as_dict["comments"]:
del result_as_dict["comments"]

self.final_result = result_as_dict
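As a side note on the `== []` → `not ...` change above, a small illustration of why the truthiness check is preferred, and the one (normally irrelevant) behavioural difference:

```python
comments = []
assert not comments      # an empty list is falsy
assert comments == []    # the old equality check agrees for lists

comments = None
assert not comments      # truthiness also treats None as "no comments" ...
assert comments != []    # ... whereas `== []` would not have matched None
```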
39 changes: 19 additions & 20 deletions simple_ddl_parser/output/table_data.py
@@ -3,6 +3,21 @@
from simple_ddl_parser.output.base_data import BaseData
from simple_ddl_parser.output.dialects import CommonDialectsFieldsMixin, dialect_by_name

__all__ = [
"TableData",
]


def _pre_process_kwargs(kwargs: dict, aliased_fields: dict) -> None:
for alias, field_name in aliased_fields.items():
if alias in kwargs:
kwargs[field_name] = kwargs[alias]
del kwargs[alias]

# todo: need to figure out how workaround it normally
if kwargs.get("fields_terminated_by") == "_ddl_parser_comma_only_str":
kwargs["fields_terminated_by"] = "','"


class TableData:
cls_prefix = "Dialect"
@@ -13,34 +28,18 @@ def get_dialect_class(cls, kwargs: dict):

if output_mode and output_mode != "sql":
main_cls = dialect_by_name.get(output_mode)
cls = dataclass(
return dataclass(
type(
f"{main_cls.__name__}{cls.cls_prefix}",
(main_cls, CommonDialectsFieldsMixin),
{},
)
)
else:
cls = BaseData

return cls

@staticmethod
def pre_process_kwargs(kwargs: dict, aliased_fields: dict) -> dict:
for alias, field_name in aliased_fields.items():
if alias in kwargs:
kwargs[field_name] = kwargs[alias]
del kwargs[alias]

# todo: need to figure out how workaround it normally
if (
"fields_terminated_by" in kwargs
and "_ddl_parser_comma_only_str" == kwargs["fields_terminated_by"]
):
kwargs["fields_terminated_by"] = "','"
return BaseData

@classmethod
def pre_load_mods(cls, main_cls, kwargs):
def pre_load_mods(cls, main_cls, kwargs) -> dict:
if kwargs.get("output_mode") == "bigquery":
if kwargs.get("schema"):
kwargs["dataset"] = kwargs["schema"]
@@ -55,7 +54,7 @@ def pre_load_mods(cls, main_cls, kwargs):
for name, value in cls_fields.items()
if value.metadata and "alias" in value.metadata
}
cls.pre_process_kwargs(kwargs, aliased_fields)
_pre_process_kwargs(kwargs, aliased_fields)
table_main_args = {
k.lower(): v for k, v in kwargs.items() if k.lower() in cls_fields
}
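`get_dialect_class` builds the output dataclass at runtime by combining a dialect class with the common-fields mixin. A standalone sketch of the same `dataclass(type(...))` pattern, using illustrative stand-in classes rather than the library's real ones:

```python
from dataclasses import dataclass, field


@dataclass
class FakeDialect:               # stand-in for a dialect_by_name entry
    table_name: str = ""


@dataclass
class FakeCommonFieldsMixin:     # stand-in for CommonDialectsFieldsMixin
    comments: list = field(default_factory=list)


# Mirror get_dialect_class: create "<Dialect>Dialect" on the fly and wrap it
# with dataclass() so the inherited fields become constructor arguments.
Combined = dataclass(
    type("FakeDialectDialect", (FakeDialect, FakeCommonFieldsMixin), {})
)

print(Combined(table_name="users"))
# FakeDialectDialect(comments=[], table_name='users')
```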
8 changes: 3 additions & 5 deletions simple_ddl_parser/parser.py
@@ -6,12 +6,10 @@

from ply import lex, yacc

from simple_ddl_parser.exception import DDLParserError
from simple_ddl_parser.output.core import Output, dump_data_to_file
from simple_ddl_parser.output.dialects import dialect_by_name
from simple_ddl_parser.utils import (
SimpleDDLParserException,
find_first_unpair_closed_par,
)
from simple_ddl_parser.utils import find_first_unpair_closed_par

# open comment
OP_COM = "/*"
@@ -348,7 +346,7 @@ def run(
Dict == one entity from ddl - one table or sequence or type.
"""
if output_mode not in dialect_by_name:
raise SimpleDDLParserException(
Author — @demitryfly, Jun 14, 2024

I am not sure whether anyone uses this exception in their code (to catch it in a try/except, for instance). If that is possible, it would be more correct to keep an alias and deprecate the old name officially.

raise DDLParserError(
f"Output mode can be one of possible variants: {dialect_by_name.keys()}"
)
self.tables = self.parse_data()
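On the compatibility concern raised in the review comment above: one possible approach (hypothetical, not part of this PR) is a PEP 562 module-level `__getattr__` in `simple_ddl_parser/utils.py` that keeps the old name importable while emitting a deprecation warning:

```python
# Hypothetical shim, not part of this PR.
import warnings

from simple_ddl_parser.exception import DDLParserError


def __getattr__(name: str):
    # `from simple_ddl_parser.utils import SimpleDDLParserException` keeps
    # working, and old except-blocks still catch what the parser now raises,
    # because the returned object is the same DDLParserError class.
    if name == "SimpleDDLParserException":
        warnings.warn(
            "SimpleDDLParserException is deprecated; use "
            "simple_ddl_parser.exception.DDLParserError instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return DDLParserError
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```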
20 changes: 10 additions & 10 deletions simple_ddl_parser/tokens.py
@@ -150,8 +150,8 @@


tokens = tuple(
set(
[
{
*[
"ID",
"DOT",
"STRING_BASE",
@@ -161,14 +161,14 @@
"LT",
"RT",
"COMMAT",
]
+ list(definition_statements.values())
+ list(common_statements.values())
+ list(columns_definition.values())
+ list(sequence_reserved.values())
+ list(after_columns_tokens.values())
+ list(alter_tokens.values())
)
],
*definition_statements.values(),
*common_statements.values(),
*columns_definition.values(),
*sequence_reserved.values(),
*after_columns_tokens.values(),
*alter_tokens.values(),
}
)

symbol_tokens = {
82 changes: 48 additions & 34 deletions simple_ddl_parser/utils.py
@@ -1,12 +1,31 @@
import re
from typing import List
from typing import List, Tuple, Optional, Union, Any

__all__ = [
"remove_par",
"check_spec",
"find_first_unpair_closed_par",
"normalize_name",
"get_table_id",
]

def remove_par(p_list: List[str]) -> List[str]:
remove_list = ["(", ")"]
for symbol in remove_list:
while symbol in p_list:
p_list.remove(symbol)
_parentheses = ('(', ')')


def remove_par(p_list: List[Union[str, Any]]) -> List[Union[str, Any]]:
"""
Remove the parentheses from the given list

Warn: p_list may contain unhashable types for some unexplored reasons
Author: It would be better to find out why p_list may contain a dict. Is that expected?

Owner: Yes, it is expected. p_list holds the results of parsing statements, usually something like {'column': name, 'unique': True}, so p_list is always a list, but its elements can be dicts.

Author: OK, I have rewritten the docstring accordingly.

"""
i = j = 0
while i < len(p_list):
if p_list[i] not in _parentheses:
p_list[j] = p_list[i]
j += 1
i += 1
while j < len(p_list):
p_list.pop()
return p_list
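A quick usage sketch of the rewritten helper, reflecting the thread above: elements of `p_list` can be unhashable dicts, so the filtering is done in place by index rather than via a set (hypothetical input, in the spirit of the tests added below):

```python
from simple_ddl_parser.utils import remove_par

tokens = ["(", {"column": "id", "unique": True}, ")", "PRIMARY"]
print(remove_par(tokens))
# [{'column': 'id', 'unique': True}, 'PRIMARY']
```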


@@ -18,44 +37,39 @@ def remove_par(p_list: List[str]) -> List[str]:
}


# TODO: Add tests
Author — @demitryfly, Jun 13, 2024

I am going to add some tests for the check_spec function in this PR.

def check_spec(value: str) -> str:
replace_value = spec_mapper.get(value)
if not replace_value:
for item in spec_mapper:
if item in value:
replace_value = value.replace(item, spec_mapper[item])
break
else:
replace_value = value
return replace_value


def find_first_unpair_closed_par(str_: str) -> int:
stack = []
n = -1
for i in str_:
n += 1
if i == ")":
if not stack:
return n
else:
stack.pop(-1)
elif i == "(":
stack.append(i)
if replace_value:
return replace_value
for item in spec_mapper:
if item in value:
return value.replace(item, spec_mapper[item])
return value


def find_first_unpair_closed_par(str_: str) -> Optional[int]:
count_open = 0
for i, char in enumerate(str_):
if char == '(':
count_open += 1
if char == ')':
count_open -= 1
if count_open < 0:
return i
return None


def normalize_name(name: str) -> str:
# clean up [] and " symbols from names
"""
Clean up [] and " characters from the given name
"""
clean_up_re = r'[\[\]"]'
return re.sub(clean_up_re, "", name).lower()


def get_table_id(schema_name: str, table_name: str):
def get_table_id(schema_name: str, table_name: str) -> Tuple[str, str]:
table_name = normalize_name(table_name)
if schema_name:
schema_name = normalize_name(schema_name)
return (table_name, schema_name)


class SimpleDDLParserException(Exception):
pass
45 changes: 45 additions & 0 deletions tests/test_utils.py
@@ -0,0 +1,45 @@
import pytest

from simple_ddl_parser import utils


@pytest.mark.parametrize(
"expression, expected_result",
[
(")", 0),
(")()", 0),
("())", 2),
("()())", 4),
("", None),
("text", None),
("()", None),
("(balanced) (brackets)", None),
("(not)) (balanced) (brackets", 5)
]
)
def test_find_first_unpair_closed_par(expression, expected_result):
assert utils.find_first_unpair_closed_par(expression) == expected_result


@pytest.mark.parametrize(
"expression, expected_result",
[
([], []),
(["("], []),
([")"], []),
(["(", ")"], []),
([")", "("], []),
(["(", "A"], ["A"]),
(["A", ")"], ["A"]),
(["(", "A", ")"], ["A"]),
(["A", ")", ")"], ["A"]),
(["(", "(", "A"], ["A"]),
(["A", "B", "C"], ["A", "B", "C"]),
(["A", "(", "(", "B", "C", "("], ["A", "B", "C"]),
(["A", ")", "B", ")", "(", "C"], ["A", "B", "C"]),
(["(", "A", ")", "B", "C", ")"], ["A", "B", "C"]),
([dict()], [dict()]), # Edge case (unhashable types)
]
)
def test_remove_par(expression, expected_result):
assert utils.remove_par(expression) == expected_result
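Per the TODO thread above, tests for `check_spec` are still planned in this PR. A hedged sketch of what they might look like, in the same parametrized style (assuming the replacement values in `utils.spec_mapper` are non-empty strings):

```python
import pytest

from simple_ddl_parser import utils


@pytest.mark.parametrize("spec_key", list(utils.spec_mapper))
def test_check_spec_exact_key(spec_key):
    # An exact spec_mapper key comes back as its mapped replacement.
    assert utils.check_spec(spec_key) == utils.spec_mapper[spec_key]


def test_check_spec_key_inside_value():
    # A mapped key embedded in a longer string is replaced in place.
    first_key = next(iter(utils.spec_mapper))
    value = f"prefix{first_key}suffix"
    expected = value.replace(first_key, utils.spec_mapper[first_key])
    assert utils.check_spec(value) == expected
```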