Skip to content

Commit

Permalink
Merge pull request #13 from verkada/jackkeane/fix_qualname_parsing
Browse files Browse the repository at this point in the history
Fix testing errors
  • Loading branch information
jakeane committed Nov 14, 2023
2 parents 710529f + 8b77254 commit c98952e
Show file tree
Hide file tree
Showing 11 changed files with 198 additions and 80 deletions.
38 changes: 29 additions & 9 deletions flask_pydantic_spec/flask_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ def __init__(self, validator: Any) -> None:

def find_routes(self) -> Any:
for rule in self.app.url_map.iter_rules():
if any(str(rule).startswith(path) for path in (f"/{self.config.PATH}", "/static")):
if any(
str(rule).startswith(path)
for path in (f"/{self.config.PATH}", "/static")
):
continue
yield rule

Expand All @@ -55,7 +58,6 @@ def parse_func(self, route: Any) -> Any:
yield method, func

def parse_path(self, route: Rule) -> Tuple[str, List[Any]]:

subs = []
parameters = []

Expand All @@ -75,7 +77,10 @@ def parse_path(self, route: Rule) -> Tuple[str, List[Any]]:
if converter == "any":
schema = {
"type": "array",
"items": {"type": "string", "enum": args,},
"items": {
"type": "string",
"enum": args,
},
}
elif converter == "int":
schema = {
Expand Down Expand Up @@ -112,7 +117,12 @@ def parse_path(self, route: Rule) -> Tuple[str, List[Any]]:
schema = {"type": "string"}

parameters.append(
{"name": variable, "in": "path", "required": True, "schema": schema,}
{
"name": variable,
"in": "path",
"required": True,
"schema": schema,
}
)

return "".join(subs), parameters
Expand All @@ -132,10 +142,14 @@ def request_validation(
req_query = {}
if request.content_type and "application/json" in request.content_type:
if request.content_encoding and "gzip" in request.content_encoding:
raw_body = gzip.decompress(request.stream.read()).decode(encoding="utf-8")
raw_body = gzip.decompress(request.stream.read()).decode(
encoding="utf-8"
)
parsed_body = json.loads(raw_body)
else:
parsed_body = {} if request.get_data() == b"" else request.get_json(force=True)
parsed_body = (
{} if request.get_data() == b"" else request.get_json(force=True)
)
elif request.content_type and "multipart/form-data" in request.content_type:
parsed_body = parse_multi_dict(request.form) if request.form else {}
else:
Expand Down Expand Up @@ -173,7 +187,9 @@ def validate(
self.request_validation(request, query, body, headers, cookies)
except ValidationError as err:
req_validation_error = err
response = make_response(jsonify(err.errors()), self.config.VALIDATION_ERROR_CODE)
response = make_response(
jsonify(err.errors()), self.config.VALIDATION_ERROR_CODE
)

before(request, response, req_validation_error, None)
if req_validation_error:
Expand All @@ -188,7 +204,9 @@ def validate(
model.validate(response.get_json())
except ValidationError as err:
resp_validation_error = err
response = make_response(jsonify({"message": "response validation error"}), 500)
response = make_response(
jsonify({"message": "response validation error"}), 500
)

after(request, response, resp_validation_error, None)

Expand All @@ -199,7 +217,9 @@ def register_route(self, app: Flask) -> None:
from flask import jsonify

self.app.add_url_rule(
self.config.spec_url, "openapi", lambda: jsonify(self.validator.spec),
self.config.spec_url,
"openapi",
lambda: jsonify(self.validator.spec),
)

for ui in PAGES:
Expand Down
91 changes: 61 additions & 30 deletions flask_pydantic_spec/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,20 @@ def __init__(
self.models: Dict[str, Any] = {}
if app:
self.register(app)
self.class_view_api_info = dict() # class view info when adding validate decorator
self.class_view_apispec = dict() # convert class_view_api_info into openapi spec
self.routes_by_category = dict() # routes openapi info by category as key in the dict
self._spec_by_category = dict() # openapi spec by category
self._models_by_category = defaultdict(dict) # model schemas by category
self.tags = dict()
self.class_view_api_info: Dict[
str, dict
] = dict() # class view info when adding validate decorator
self.class_view_apispec: Dict[
str, dict
] = dict() # convert class_view_api_info into openapi spec
self.routes_by_category: Dict[
str, dict
] = dict() # routes openapi info by category as key in the dict
self._spec_by_category: Dict[str, Mapping] = dict() # openapi spec by category
self._models_by_category: Dict[str, dict] = defaultdict(
dict
) # model schemas by category
self.tags: Dict[str, Mapping] = dict()

def register(self, app: Flask) -> None:
"""
Expand All @@ -88,7 +96,7 @@ def spec(self) -> Mapping[str, Any]:
self._spec = self._generate_spec()
return self._spec

def spec_by_category(self, category) -> Mapping[str, Any]:
def spec_by_category(self, category: str) -> Mapping[str, Any]:
"""
get OpenAPI spec by category
:return:
Expand Down Expand Up @@ -124,9 +132,10 @@ def bypass(self, func: Callable) -> bool:
return False

def bypass_unpublish(self, func: Callable) -> bool:
""" bypass unpublished APIs under publish_only mode"""
"""bypass unpublished APIs under publish_only mode"""
if self.config.MODE == "publish_only":
return not getattr(func, "publish", False)
return False

def validate(
self,
Expand Down Expand Up @@ -181,7 +190,7 @@ def sync_validate(*args: Any, **kwargs: Any) -> FlaskResponse:
params = []
if "." in func.__qualname__:
class_view = True
view_name, method = func.__qualname__.split(".")
view_name, *_, method = func.__qualname__.split(".")
if view_name not in self.class_view_api_info:
self.class_view_api_info[view_name] = {method: {}}
else:
Expand All @@ -205,21 +214,24 @@ def sync_validate(*args: Any, **kwargs: Any) -> FlaskResponse:
else:
_model = model
if _model:
self.models[_model.__name__] = self._get_open_api_schema(_model.schema())
self.models[_model.__name__] = self._get_open_api_schema(
_model.schema()
)
self._models_by_category[category][
_model.__name__
] = self._get_open_api_schema(_model.schema())
setattr(validation, name, model)

if class_view:
if class_view and _model:
model_schema = self._get_open_api_schema(_model.schema())
for param_name, schema in model_schema["properties"].items():
params.append(
{
"name": param_name,
"in": name,
"schema": schema,
"required": param_name in model_schema.get("required", []),
"required": param_name
in model_schema.get("required", []),
}
)

Expand All @@ -228,21 +240,25 @@ def sync_validate(*args: Any, **kwargs: Any) -> FlaskResponse:
param for param in params if param["in"] == "query"
]
if hasattr(validation, "body"):
self.class_view_api_info[view_name][method]["requestBody"] = parse_request(
validation
)
self.class_view_api_info[view_name][method][
"requestBody"
] = parse_request(validation)

if resp:
for model in resp.models:
if model:
assert not isinstance(model, RequestBase)
self.models[model.__name__] = self._get_open_api_schema(model.schema())
self.models[model.__name__] = self._get_open_api_schema(
model.schema()
)
self._models_by_category[category][
model.__name__
] = self._get_open_api_schema(model.schema())
if class_view:
for k, v in resp.generate_spec().items():
self.class_view_api_info[view_name][method]["responses"][k] = v
self.class_view_api_info[view_name][method][
"responses"
][k] = v
setattr(validation, "resp", resp)

if tags:
Expand All @@ -260,12 +276,17 @@ def sync_validate(*args: Any, **kwargs: Any) -> FlaskResponse:

return decorate_validation

def _generate_spec_common(self, routes, category=None):
spec = {
def _generate_spec_common(
self, routes: dict, category: Optional[str] = None
) -> dict:
spec: Dict[str, Any] = {
"openapi": self.config.OPENAPI_VERSION,
"info": {
**self.config.INFO,
**{"title": self.config.TITLE, "version": self.config.VERSION,},
**{
"title": self.config.TITLE,
"version": self.config.VERSION,
},
},
"tags": list(self.tags.values()),
"paths": {**routes},
Expand Down Expand Up @@ -321,10 +342,16 @@ def _generate_spec(self) -> Mapping[str, Any]:
operation_id = camelize(method.lower() + name, False)
desc = self.class_view_apispec[path][method.lower()]["description"]
func_tag = self.class_view_apispec[path][method.lower()]["tags"]
query_parameters = self.class_view_apispec[path][method.lower()]["parameters"]
path_parameters = [param for param in parameters if param["in"] == "path"]
query_parameters = self.class_view_apispec[path][method.lower()][
"parameters"
]
path_parameters = [
param for param in parameters if param["in"] == "path"
]
parameters = path_parameters + query_parameters
responses = self.class_view_apispec[path][method.lower()]["responses"]
responses = self.class_view_apispec[path][method.lower()][
"responses"
]
request_body = self.class_view_apispec[path][method.lower()].get(
"requestBody", None
)
Expand Down Expand Up @@ -368,9 +395,9 @@ def _generate_spec(self) -> Mapping[str, Any]:
routes[path][method.lower()]["deprecated"] = True

if request_body:
routes[path][method.lower()]["requestBody"] = self._parse_request_body(
request_body
)
routes[path][method.lower()][
"requestBody"
] = self._parse_request_body(request_body)
self.routes_by_category[category][path][method.lower()][
"requestBody"
] = self._parse_request_body(request_body)
Expand Down Expand Up @@ -444,9 +471,11 @@ def _get_open_api_schema(self, schema: Mapping[str, Any]) -> Mapping[str, Any]:
result[key] = self._validate_property(value)
else:
result[key] = value
return cast(Mapping[str, Any], nested_alter(result, "$ref", _move_schema_reference))
return cast(
Mapping[str, Any], nested_alter(result, "$ref", _move_schema_reference)
)

def _get_model_definitions(self, category=None) -> Dict[str, Any]:
def _get_model_definitions(self, category: Optional[str] = None) -> Dict[str, Any]:
"""
handle nested models
"""
Expand Down Expand Up @@ -476,11 +505,13 @@ def _parse_request_body(self, request_body: Mapping[str, Any]) -> Mapping[str, A
schema = request_body["content"][content_type]["schema"]
if "$ref" not in schema.keys():
# handle inline schema definitions
return {"content": {content_type: {"schema": self._get_open_api_schema(schema)}}}
return {
"content": {content_type: {"schema": self._get_open_api_schema(schema)}}
}
else:
return request_body

def register_class_view_apidoc(self, target):
def register_class_view_apidoc(self, target: Any) -> None:
endpoint = target.__name__
rules = self.app.url_map._rules_by_endpoint[endpoint]
for rule in rules:
Expand Down
32 changes: 24 additions & 8 deletions flask_pydantic_spec/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class Response(ResponseBase):
"""

def __init__(self, *args: Any, **kwargs: Any):

self.validate = True
self.codes = []
for item in args:
Expand Down Expand Up @@ -117,17 +116,22 @@ def generate_spec(self) -> Mapping[str, Any]:
responses = {
"200": {
"description": DEFAULT_CODE_DESC["HTTP_200"],
"content": {self.content_type: {"schema": {"type": "string", "format": "binary"}}},
"content": {
self.content_type: {
"schema": {"type": "string", "format": "binary"}
}
},
},
"404": {"description": DEFAULT_CODE_DESC["HTTP_404"]},
}

return responses


class HLSFileResponse(ResponseBase):
def __init__(self, content_types: List[str] = None):
def __init__(self, content_types: Optional[List[str]] = None):
self.content_types = content_types

def has_model(self) -> bool:
"""
HLS file response cannot have a model
Expand All @@ -142,13 +146,18 @@ def generate_spec(self) -> Mapping[str, Any]:
responses = {
"200": {
"description": DEFAULT_CODE_DESC["HTTP_200"],
"content": {self.content_types: {"schema": {"type": "string", "format": "binary"}}},
"content": {
self.content_types: {
"schema": {"type": "string", "format": "binary"}
}
},
},
"404": {"description": DEFAULT_CODE_DESC["HTTP_404"]},
}

return responses


class RequestBase:
def has_model(self) -> bool:
raise NotImplemented
Expand All @@ -175,15 +184,19 @@ def generate_spec(self) -> Mapping[str, Any]:
if self.content_type == "application/octet-stream":
return {
"content": {
self.content_type: {"schema": {"type": "string", "format": self.encoding}}
self.content_type: {
"schema": {"type": "string", "format": self.encoding}
}
}
}
else:
assert self.model is not None
return {
"content": {
self.content_type: {
"schema": {"$ref": f"#/components/schemas/{self.model.__name__}"}
"schema": {
"$ref": f"#/components/schemas/{self.model.__name__}"
}
}
}
}
Expand Down Expand Up @@ -218,7 +231,10 @@ def generate_spec(self) -> Mapping[str, Any]:
"type": "object",
"properties": {
**additional_properties,
self.file_key: {"type": "string", "format": self.encoding,},
self.file_key: {
"type": "string",
"format": self.encoding,
},
},
}
}
Expand Down
Loading

0 comments on commit c98952e

Please sign in to comment.