From b853c9ff3b40e2b694b8325a86e511b37d066be8 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Sat, 4 Apr 2026 19:46:11 +0200 Subject: [PATCH] refactor(event_handler): extract OpenAPI schema generation from Route class --- .../event_handler/api_gateway.py | 445 ++------------ .../event_handler/bedrock_agent.py | 17 +- .../event_handler/openapi/constants.py | 2 + .../event_handler/openapi/schema_generator.py | 549 ++++++++++++++++++ 4 files changed, 600 insertions(+), 413 deletions(-) create mode 100644 aws_lambda_powertools/event_handler/openapi/schema_generator.py diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index f7294801460..ad1e14d3122 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -1,7 +1,6 @@ from __future__ import annotations import base64 -import copy import json import logging import re @@ -23,6 +22,8 @@ from aws_lambda_powertools.event_handler.openapi.config import OpenAPIConfig from aws_lambda_powertools.event_handler.openapi.constants import ( DEFAULT_API_VERSION, + DEFAULT_CONTENT_TYPE, + DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, DEFAULT_OPENAPI_TITLE, DEFAULT_OPENAPI_VERSION, ) @@ -34,13 +35,8 @@ ) from aws_lambda_powertools.event_handler.openapi.types import ( COMPONENT_REF_PREFIX, - METHODS_WITH_BODY, OpenAPIResponse, - OpenAPIResponseContentModel, - OpenAPIResponseContentSchema, response_validation_error_response_definition, - validation_error_definition, - validation_error_response_definition, ) from aws_lambda_powertools.event_handler.request import Request from aws_lambda_powertools.event_handler.util import ( @@ -73,10 +69,8 @@ # API GW/ALB decode non-safe URI chars; we must support them too _UNSAFE_URI = r"%<> \[\]{}|^" _NAMED_GROUP_BOUNDARY_PATTERN = rf"(?P\1[{_SAFE_URI}{_UNSAFE_URI}\\w]+)" -_DEFAULT_OPENAPI_RESPONSE_DESCRIPTION = "Successful Response" _ROUTE_REGEX = "^{}$" _JSON_DUMP_CALL = partial(json.dumps, separators=(",", ":"), cls=Encoder) -_DEFAULT_CONTENT_TYPE = "application/json" ResponseEventT = TypeVar("ResponseEventT", bound=BaseProxyEvent) ResponseT = TypeVar("ResponseT") @@ -95,7 +89,7 @@ Server, Tag, ) - from aws_lambda_powertools.event_handler.openapi.params import Dependant, Param + from aws_lambda_powertools.event_handler.openapi.params import Dependant from aws_lambda_powertools.event_handler.openapi.swagger_ui.oauth2 import ( OAuth2Config, ) @@ -279,7 +273,7 @@ def __init__( self, body: Any = None, status_code: int = 200, - content_type: str = _DEFAULT_CONTENT_TYPE, + content_type: str = DEFAULT_CONTENT_TYPE, session_attributes: dict[str, Any] | None = None, prompt_session_attributes: dict[str, Any] | None = None, knowledge_bases_configuration: list[dict[str, Any]] | None = None, @@ -355,7 +349,7 @@ def is_json(self) -> bool: content_type = self.headers.get("Content-Type", "") if isinstance(content_type, list): content_type = content_type[0] - return content_type.startswith(_DEFAULT_CONTENT_TYPE) + return content_type.startswith(DEFAULT_CONTENT_TYPE) class Route: @@ -617,7 +611,7 @@ def body_field(self) -> ModelField | None: return self._body_field - def _get_openapi_path( # noqa PLR0912 + def _get_openapi_path( self, *, dependant: Dependant, @@ -628,393 +622,32 @@ def _get_openapi_path( # noqa PLR0912 ) -> tuple[dict[str, Any], dict[str, Any]]: """ Returns the OpenAPI path and definitions for the route. - """ - from aws_lambda_powertools.event_handler.openapi.dependant import get_flat_params - - definitions: dict[str, Any] = {} - - # Gather all the route parameters - operation = self._openapi_operation_metadata(operation_ids=operation_ids) - parameters: list[dict[str, Any]] = [] - all_route_params = get_flat_params(dependant) - operation_params = self._openapi_operation_parameters( - all_route_params=all_route_params, - model_name_map=model_name_map, - field_mapping=field_mapping, - ) - parameters.extend(operation_params) - - # Add security if present - if self.security: - operation["security"] = self.security - - # Add OpenAPI extensions if present - if self.openapi_extensions: - operation.update(self.openapi_extensions) - - # Add the parameters to the OpenAPI operation - if parameters: - all_parameters = {(param["in"], param["name"]): param for param in parameters} - required_parameters = {(param["in"], param["name"]): param for param in parameters if param.get("required")} - all_parameters.update(required_parameters) - operation["parameters"] = list(all_parameters.values()) - - # Add the request body to the OpenAPI operation, if applicable - if self.method.upper() in METHODS_WITH_BODY: - request_body_oai = self._openapi_operation_request_body( - body_field=self.body_field, - model_name_map=model_name_map, - field_mapping=field_mapping, - ) - if request_body_oai: - operation["requestBody"] = request_body_oai - - operation_responses: dict[int, OpenAPIResponse] = {} - - if enable_validation: - # Validation failure response (422) is added only if Enable Validation feature is true - operation_responses = { - 422: { - "description": "Validation Error", - "content": { - _DEFAULT_CONTENT_TYPE: {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}HTTPValidationError"}}, - }, - }, - } - - # Add custom response validation response, if exists - if self.custom_response_validation_http_code: - http_code = self.custom_response_validation_http_code.value - operation_responses[http_code] = { - "description": "Response Validation Error", - "content": { - _DEFAULT_CONTENT_TYPE: {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"}}, - }, - } - # Add model definition - definitions["ResponseValidationError"] = response_validation_error_response_definition - - # Add the response to the OpenAPI operation - if self.responses: - for status_code in list(self.responses): - # Create a deep copy to prevent mutation of the shared dictionary - response = copy.deepcopy(self.responses[status_code]) - - # Case 1: there is not 'content' key - if "content" not in response: - response["content"] = { - _DEFAULT_CONTENT_TYPE: self._openapi_operation_return( - param=dependant.return_param, - model_name_map=model_name_map, - field_mapping=field_mapping, - ), - } - - # Case 2: there is a 'content' key - else: - # Need to iterate to transform any 'model' into a 'schema' - for content_type, payload in response["content"].items(): - # Case 2.1: the 'content' has a model - if "model" in payload: - # Find the model in the dependant's extra models - model_payload_typed = cast(OpenAPIResponseContentModel, payload) - return_field = next( - filter( - lambda model: model.type_ is model_payload_typed["model"], - self.dependant.response_extra_models, - ), - ) - if not return_field: - raise AssertionError("Model declared in custom responses was not found") - - model_payload = self._openapi_operation_return( - param=return_field, - model_name_map=model_name_map, - field_mapping=field_mapping, - ) - - # Preserve existing fields like examples, encoding, etc. - new_payload: OpenAPIResponseContentSchema = {} - for key, value in payload.items(): - if key != "model": - new_payload[key] = value # type: ignore[literal-required] - new_payload.update(model_payload) # Add/override with model schema - - # Case 2.2: the 'content' has a schema - else: - # Do nothing! We already have what we need! - new_payload = cast(OpenAPIResponseContentSchema, payload) - - response["content"][content_type] = new_payload - - # Merge the user provided response with the default responses - operation_responses[status_code] = response - else: - # Set the default 200 response - response_schema = self._openapi_operation_return( - param=dependant.return_param, - model_name_map=model_name_map, - field_mapping=field_mapping, - ) - - # Add the response schema to the OpenAPI 200 response - operation_responses[200] = { - "description": self.response_description or _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, - "content": {_DEFAULT_CONTENT_TYPE: response_schema}, - } - - operation["responses"] = operation_responses - path = {self.method.lower(): operation} - # Add the validation error schema to the definitions, but only if it hasn't been added yet - if "ValidationError" not in definitions: - definitions.update( - { - "ValidationError": validation_error_definition, - "HTTPValidationError": validation_error_response_definition, - }, - ) - - # Generate the response schema - return path, definitions - - def _openapi_operation_summary(self) -> str: - """ - Returns the OpenAPI operation summary. If the user has not provided a summary, we - generate one based on the route path and method. - """ - return self.summary or f"{self.method.upper()} {self.openapi_path}" - - def _openapi_operation_metadata(self, operation_ids: set[str]) -> dict[str, Any]: - """ - Returns the OpenAPI operation metadata. If the user has not provided a description, we - generate one based on the route path and method. - """ - operation: dict[str, Any] = {} - - # Ensure tags is added to the operation - if self.tags: - operation["tags"] = self.tags - - # Ensure summary is added to the operation - operation["summary"] = self._openapi_operation_summary() - - # Ensure description is added to the operation - if self.description: - operation["description"] = self.description - - # Ensure operationId is unique - if self.operation_id in operation_ids: - message = f"Duplicate Operation ID {self.operation_id} for function {self.func.__name__}" - file_name = getattr(self.func, "__globals__", {}).get("__file__") - if file_name: - message += f" in {file_name}" - warnings.warn(message, stacklevel=1) - - # Adds the operation - operation_ids.add(self.operation_id) - operation["operationId"] = self.operation_id - - # Mark as deprecated if necessary - operation["deprecated"] = self.deprecated or None - - return operation - - @staticmethod - def _openapi_operation_request_body( - *, - body_field: ModelField | None, - model_name_map: dict[TypeModelOrEnum, str], - field_mapping: dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue], - ) -> dict[str, Any] | None: - """ - Returns the OpenAPI operation request body. - """ - from aws_lambda_powertools.event_handler.openapi.compat import ModelField, get_schema_from_model_field - from aws_lambda_powertools.event_handler.openapi.params import Body - - # Check that there is a body field and it's a Pydantic's model field - if not body_field: - return None - - if not isinstance(body_field, ModelField): - raise AssertionError(f"Expected ModelField, got {body_field}") - - # Generate the request body schema - body_schema = get_schema_from_model_field( - field=body_field, - model_name_map=model_name_map, - field_mapping=field_mapping, - ) - - field_info = cast(Body, body_field.field_info) - request_media_type = field_info.media_type - required = body_field.required - request_body_oai: dict[str, Any] = {} - if required: - request_body_oai["required"] = required - - if field_info.description: - request_body_oai["description"] = field_info.description - # Generate the request body media type - request_media_content: dict[str, Any] = {"schema": body_schema} - if field_info.openapi_examples: - request_media_content["examples"] = field_info.openapi_examples - request_body_oai["content"] = {request_media_type: request_media_content} - return request_body_oai - - @staticmethod - def _openapi_operation_parameters( - *, - all_route_params: Sequence[ModelField], - model_name_map: dict[TypeModelOrEnum, str], - field_mapping: dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue], - ) -> list[dict[str, Any]]: - """ - Returns the OpenAPI operation parameters. - """ - from aws_lambda_powertools.event_handler.openapi.params import Param - - parameters: list[dict[str, Any]] = [] - - for param in all_route_params: - field_info = cast(Param, param.field_info) - if not field_info.include_in_schema: - continue - - # Check if this is a Pydantic model that should be expanded - if Route._is_pydantic_model_param(field_info): - parameters.extend(Route._expand_pydantic_model_parameters(field_info)) - else: - parameters.append(Route._create_regular_parameter(param, model_name_map, field_mapping)) - - return parameters - - @staticmethod - def _is_pydantic_model_param(field_info: Param) -> bool: - """Check if the field info represents a Pydantic model parameter.""" - from pydantic import BaseModel - - from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass - - return lenient_issubclass(field_info.annotation, BaseModel) - - @staticmethod - def _expand_pydantic_model_parameters(field_info: Param) -> list[dict[str, Any]]: - """Expand a Pydantic model into individual OpenAPI parameters.""" - from pydantic import BaseModel - - model_class = cast(type[BaseModel], field_info.annotation) - parameters: list[dict[str, Any]] = [] - - for field_name, field_def in model_class.model_fields.items(): - param_name = field_def.alias or field_name - individual_param = Route._create_pydantic_field_parameter( - param_name=param_name, - field_def=field_def, - param_location=field_info.in_.value, - ) - parameters.append(individual_param) - - return parameters - - @staticmethod - def _create_pydantic_field_parameter( - param_name: str, - field_def: Any, - param_location: str, - ) -> dict[str, Any]: - """Create an OpenAPI parameter from a Pydantic field definition.""" - individual_param: dict[str, Any] = { - "name": param_name, - "in": param_location, - "required": field_def.is_required() if hasattr(field_def, "is_required") else field_def.default is ..., - "schema": Route._get_basic_type_schema(field_def.annotation or type(None)), - } - - if field_def.description: - individual_param["description"] = field_def.description - - return individual_param - - @staticmethod - def _create_regular_parameter( - param: ModelField, - model_name_map: dict[TypeModelOrEnum, str], - field_mapping: dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue], - ) -> dict[str, Any]: - """Create an OpenAPI parameter from a regular ModelField.""" - from aws_lambda_powertools.event_handler.openapi.compat import get_schema_from_model_field - from aws_lambda_powertools.event_handler.openapi.params import Param - - field_info = cast(Param, param.field_info) - param_schema = get_schema_from_model_field( - field=param, - model_name_map=model_name_map, - field_mapping=field_mapping, - ) - - parameter: dict[str, Any] = { - "name": param.alias, - "in": field_info.in_.value, - "required": param.required, - "schema": param_schema, - } - - # Add optional attributes if present - if field_info.description: - parameter["description"] = field_info.description - if field_info.openapi_examples: - parameter["examples"] = field_info.openapi_examples - if field_info.deprecated: - parameter["deprecated"] = field_info.deprecated - - return parameter - - @staticmethod - def _get_basic_type_schema(param_type: type) -> dict[str, str]: - """ - Get basic OpenAPI schema for simple types - """ - try: - # Check bool before int, since bool is a subclass of int in Python - if issubclass(param_type, bool): - return {"type": "boolean"} - elif issubclass(param_type, int): - return {"type": "integer"} - elif issubclass(param_type, float): - return {"type": "number"} - else: - return {"type": "string"} - except TypeError: - # param_type may not be a type (e.g., typing.Optional[int]), fallback to string - return {"type": "string"} - - @staticmethod - def _openapi_operation_return( - *, - param: ModelField | None, - model_name_map: dict[TypeModelOrEnum, str], - field_mapping: dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue], - ) -> OpenAPIResponseContentSchema: - """ - Returns the OpenAPI operation return. - """ - if param is None: - return {} - - from aws_lambda_powertools.event_handler.openapi.compat import ( - get_schema_from_model_field, - ) - - return_schema = get_schema_from_model_field( - field=param, + Delegates to openapi.schema_generator for the actual generation logic. + """ + from aws_lambda_powertools.event_handler.openapi.schema_generator import generate_openapi_path + + return generate_openapi_path( + method=self.method, + operation_id=self.operation_id, + summary=self.summary, + description=self.description, + openapi_path=self.openapi_path, + tags=self.tags, + deprecated=self.deprecated, + security=self.security, + openapi_extensions=self.openapi_extensions, + responses=self.responses, + response_description=self.response_description, + body_field=self.body_field, + custom_response_validation_http_code=self.custom_response_validation_http_code, + dependant=dependant, + operation_ids=operation_ids, model_name_map=model_name_map, field_mapping=field_mapping, + enable_validation=enable_validation, ) - return {"schema": return_schema} - def _generate_operation_id(self) -> str: operation_id = self.func.__name__ + self.openapi_path operation_id = re.sub(r"\W", "_", operation_id) @@ -1155,7 +788,7 @@ def route( summary: str | None = None, description: str | None = None, responses: dict[int, OpenAPIResponse] | None = None, - response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + response_description: str = DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: list[str] | None = None, operation_id: str | None = None, include_in_schema: bool = True, @@ -1218,7 +851,7 @@ def get( summary: str | None = None, description: str | None = None, responses: dict[int, OpenAPIResponse] | None = None, - response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + response_description: str = DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: list[str] | None = None, operation_id: str | None = None, include_in_schema: bool = True, @@ -1281,7 +914,7 @@ def post( summary: str | None = None, description: str | None = None, responses: dict[int, OpenAPIResponse] | None = None, - response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + response_description: str = DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: list[str] | None = None, operation_id: str | None = None, include_in_schema: bool = True, @@ -1345,7 +978,7 @@ def put( summary: str | None = None, description: str | None = None, responses: dict[int, OpenAPIResponse] | None = None, - response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + response_description: str = DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: list[str] | None = None, operation_id: str | None = None, include_in_schema: bool = True, @@ -1409,7 +1042,7 @@ def delete( summary: str | None = None, description: str | None = None, responses: dict[int, OpenAPIResponse] | None = None, - response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + response_description: str = DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: list[str] | None = None, operation_id: str | None = None, include_in_schema: bool = True, @@ -1472,7 +1105,7 @@ def patch( summary: str | None = None, description: str | None = None, responses: dict[int, OpenAPIResponse] | None = None, - response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + response_description: str = DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: list[str] | None = None, operation_id: str | None = None, include_in_schema: bool = True, @@ -1538,7 +1171,7 @@ def head( summary: str | None = None, description: str | None = None, responses: dict[int, OpenAPIResponse] | None = None, - response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + response_description: str = DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: list[str] | None = None, operation_id: str | None = None, include_in_schema: bool = True, @@ -1921,7 +1554,7 @@ def _add_resolver_response_validation_error_response_to_route( response_validation_error_response = { "description": "Response Validation Error", "content": { - _DEFAULT_CONTENT_TYPE: { + DEFAULT_CONTENT_TYPE: { "schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"}, }, }, @@ -2622,7 +2255,7 @@ def swagger_handler(): if query_params.get("format") == "json": return Response( status_code=200, - content_type=_DEFAULT_CONTENT_TYPE, + content_type=DEFAULT_CONTENT_TYPE, body=escaped_spec, ) @@ -2674,7 +2307,7 @@ def route( summary: str | None = None, description: str | None = None, responses: dict[int, OpenAPIResponse] | None = None, - response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + response_description: str = DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: list[str] | None = None, operation_id: str | None = None, include_in_schema: bool = True, @@ -3237,7 +2870,7 @@ def route( summary: str | None = None, description: str | None = None, responses: dict[int, OpenAPIResponse] | None = None, - response_description: str | None = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + response_description: str | None = DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: list[str] | None = None, operation_id: str | None = None, include_in_schema: bool = True, @@ -3355,7 +2988,7 @@ def route( summary: str | None = None, description: str | None = None, responses: dict[int, OpenAPIResponse] | None = None, - response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + response_description: str = DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: list[str] | None = None, operation_id: str | None = None, include_in_schema: bool = True, diff --git a/aws_lambda_powertools/event_handler/bedrock_agent.py b/aws_lambda_powertools/event_handler/bedrock_agent.py index 4593715e88d..e9aa82ee01f 100644 --- a/aws_lambda_powertools/event_handler/bedrock_agent.py +++ b/aws_lambda_powertools/event_handler/bedrock_agent.py @@ -7,12 +7,15 @@ from aws_lambda_powertools.event_handler import ApiGatewayResolver from aws_lambda_powertools.event_handler.api_gateway import ( - _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, BedrockResponse, ProxyEventType, ResponseBuilder, ) -from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION +from aws_lambda_powertools.event_handler.openapi.constants import ( + DEFAULT_API_VERSION, + DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + DEFAULT_OPENAPI_VERSION, +) if TYPE_CHECKING: from collections.abc import Callable @@ -118,7 +121,7 @@ def get( # type: ignore[override] cache_control: str | None = None, summary: str | None = None, responses: dict[int, OpenAPIResponse] | None = None, - response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + response_description: str = DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: list[str] | None = None, operation_id: str | None = None, include_in_schema: bool = True, @@ -161,7 +164,7 @@ def post( # type: ignore[override] cache_control: str | None = None, summary: str | None = None, responses: dict[int, OpenAPIResponse] | None = None, - response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + response_description: str = DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: list[str] | None = None, operation_id: str | None = None, include_in_schema: bool = True, @@ -204,7 +207,7 @@ def put( # type: ignore[override] cache_control: str | None = None, summary: str | None = None, responses: dict[int, OpenAPIResponse] | None = None, - response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + response_description: str = DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: list[str] | None = None, operation_id: str | None = None, include_in_schema: bool = True, @@ -247,7 +250,7 @@ def patch( # type: ignore[override] cache_control: str | None = None, summary: str | None = None, responses: dict[int, OpenAPIResponse] | None = None, - response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + response_description: str = DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: list[str] | None = None, operation_id: str | None = None, include_in_schema: bool = True, @@ -290,7 +293,7 @@ def delete( # type: ignore[override] cache_control: str | None = None, summary: str | None = None, responses: dict[int, OpenAPIResponse] | None = None, - response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + response_description: str = DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: list[str] | None = None, operation_id: str | None = None, include_in_schema: bool = True, diff --git a/aws_lambda_powertools/event_handler/openapi/constants.py b/aws_lambda_powertools/event_handler/openapi/constants.py index debe1d56736..c125e89d0e7 100644 --- a/aws_lambda_powertools/event_handler/openapi/constants.py +++ b/aws_lambda_powertools/event_handler/openapi/constants.py @@ -1,3 +1,5 @@ DEFAULT_API_VERSION = "1.0.0" DEFAULT_OPENAPI_VERSION = "3.1.0" DEFAULT_OPENAPI_TITLE = "Powertools for AWS Lambda (Python) API" +DEFAULT_CONTENT_TYPE = "application/json" +DEFAULT_OPENAPI_RESPONSE_DESCRIPTION = "Successful Response" diff --git a/aws_lambda_powertools/event_handler/openapi/schema_generator.py b/aws_lambda_powertools/event_handler/openapi/schema_generator.py new file mode 100644 index 00000000000..5d409693937 --- /dev/null +++ b/aws_lambda_powertools/event_handler/openapi/schema_generator.py @@ -0,0 +1,549 @@ +""" +OpenAPI schema generation for individual routes. + +Extracted from Route to keep route configuration and schema generation +as separate concerns. All functions here are internal. +""" + +from __future__ import annotations + +import copy +import warnings +from typing import TYPE_CHECKING, Any, Literal, cast + +from aws_lambda_powertools.event_handler.openapi.types import ( + COMPONENT_REF_PREFIX, + METHODS_WITH_BODY, + OpenAPIResponse, + OpenAPIResponseContentModel, + OpenAPIResponseContentSchema, + response_validation_error_response_definition, + validation_error_definition, + validation_error_response_definition, +) + +if TYPE_CHECKING: + from collections.abc import Sequence + from http import HTTPStatus + + from aws_lambda_powertools.event_handler.openapi.compat import ( + JsonSchemaValue, + ModelField, + ) + from aws_lambda_powertools.event_handler.openapi.params import Dependant, Param + from aws_lambda_powertools.event_handler.openapi.types import TypeModelOrEnum + +from aws_lambda_powertools.event_handler.openapi.constants import ( + DEFAULT_CONTENT_TYPE, + DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, +) + + +def generate_openapi_path( + *, + method: str, + operation_id: str, + summary: str | None, + description: str | None, + openapi_path: str, + tags: list[str], + deprecated: bool, + security: list[dict[str, list[str]]] | None, + openapi_extensions: dict[str, Any] | None, + responses: dict[int, OpenAPIResponse] | None, + response_description: str | None, + body_field: ModelField | None, + custom_response_validation_http_code: HTTPStatus | None, + dependant: Dependant, + operation_ids: set[str], + model_name_map: dict[TypeModelOrEnum, str], + field_mapping: dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue], + enable_validation: bool = False, +) -> tuple[dict[str, Any], dict[str, Any]]: + """ + Generate the OpenAPI path spec and definitions for a single route. + """ + from aws_lambda_powertools.event_handler.openapi.dependant import get_flat_params + + definitions: dict[str, Any] = {} + + # Build operation metadata + operation = _build_operation_metadata( + method=method, + operation_id=operation_id, + summary=summary, + description=description, + openapi_path=openapi_path, + tags=tags, + deprecated=deprecated, + operation_ids=operation_ids, + func_name=dependant.call.__name__ if dependant.call else "", + func_file=getattr(dependant.call, "__globals__", {}).get("__file__") if dependant.call else None, + ) + + _apply_optional_fields(operation, security=security, openapi_extensions=openapi_extensions) + + # Build parameters + all_route_params = get_flat_params(dependant) + parameters = _build_operation_parameters( + all_route_params=all_route_params, + model_name_map=model_name_map, + field_mapping=field_mapping, + ) + + if parameters: + operation["parameters"] = _deduplicate_parameters(parameters) + + # Build request body + _apply_request_body( + operation, + method=method, + body_field=body_field, + model_name_map=model_name_map, + field_mapping=field_mapping, + ) + + # Build responses + operation_responses, response_definitions = _build_responses( + responses=responses, + response_description=response_description, + custom_response_validation_http_code=custom_response_validation_http_code, + dependant=dependant, + model_name_map=model_name_map, + field_mapping=field_mapping, + enable_validation=enable_validation, + ) + definitions.update(response_definitions) + + operation["responses"] = operation_responses + path = {method.lower(): operation} + + _add_validation_error_definitions(definitions) + + return path, definitions + + +def _build_operation_metadata( + *, + method: str, + operation_id: str, + summary: str | None, + description: str | None, + openapi_path: str, + tags: list[str], + deprecated: bool, + operation_ids: set[str], + func_name: str, + func_file: str | None, +) -> dict[str, Any]: + """Build the OpenAPI operation metadata (tags, summary, operationId, etc.).""" + _warn_duplicate_operation_id(operation_id, operation_ids, func_name, func_file) + operation_ids.add(operation_id) + + operation: dict[str, Any] = { + "summary": summary or f"{method.upper()} {openapi_path}", + "operationId": operation_id, + "deprecated": deprecated or None, + } + + if tags: + operation["tags"] = tags + if description: + operation["description"] = description + + return operation + + +def _build_operation_parameters( + *, + all_route_params: Sequence[ModelField], + model_name_map: dict[TypeModelOrEnum, str], + field_mapping: dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue], +) -> list[dict[str, Any]]: + """Build the list of OpenAPI operation parameters.""" + from aws_lambda_powertools.event_handler.openapi.params import Param + + parameters: list[dict[str, Any]] = [] + + for param in all_route_params: + field_info = cast(Param, param.field_info) + if not field_info.include_in_schema: + continue + + if _is_pydantic_model_param(field_info): + parameters.extend(_expand_pydantic_model_parameters(field_info)) + else: + parameters.append(_create_regular_parameter(param, model_name_map, field_mapping)) + + return parameters + + +def _build_request_body( + *, + body_field: ModelField | None, + model_name_map: dict[TypeModelOrEnum, str], + field_mapping: dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue], +) -> dict[str, Any] | None: + """Build the OpenAPI request body spec.""" + from aws_lambda_powertools.event_handler.openapi.compat import ModelField as ModelFieldClass + from aws_lambda_powertools.event_handler.openapi.compat import get_schema_from_model_field + from aws_lambda_powertools.event_handler.openapi.params import Body + + if not body_field: + return None + + if not isinstance(body_field, ModelFieldClass): + raise AssertionError(f"Expected ModelField, got {body_field}") + + body_schema = get_schema_from_model_field( + field=body_field, + model_name_map=model_name_map, + field_mapping=field_mapping, + ) + + field_info = cast(Body, body_field.field_info) + + request_body_oai: dict[str, Any] = {} + if body_field.required: + request_body_oai["required"] = body_field.required + if field_info.description: + request_body_oai["description"] = field_info.description + + request_body_oai["content"] = { + field_info.media_type: _build_media_content(body_schema, field_info.openapi_examples), + } + return request_body_oai + + +def _build_responses( + *, + responses: dict[int, OpenAPIResponse] | None, + response_description: str | None, + custom_response_validation_http_code: HTTPStatus | None, + dependant: Dependant, + model_name_map: dict[TypeModelOrEnum, str], + field_mapping: dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue], + enable_validation: bool, +) -> tuple[dict[int, OpenAPIResponse], dict[str, Any]]: + """Build the OpenAPI response specs and any extra definitions.""" + definitions: dict[str, Any] = {} + operation_responses: dict[int, OpenAPIResponse] = {} + + _add_validation_responses(operation_responses, enable_validation=enable_validation) + _add_response_validation_error( + operation_responses, + definitions, + custom_response_validation_http_code=custom_response_validation_http_code, + ) + + if responses: + for status_code in list(responses): + operation_responses[status_code] = _build_custom_response( + response=copy.deepcopy(responses[status_code]), + dependant=dependant, + model_name_map=model_name_map, + field_mapping=field_mapping, + ) + else: + response_schema = _build_return_schema( + param=dependant.return_param, + model_name_map=model_name_map, + field_mapping=field_mapping, + ) + + operation_responses[200] = { + "description": response_description or DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + "content": {DEFAULT_CONTENT_TYPE: response_schema}, + } + + return operation_responses, definitions + + +def _build_return_schema( + *, + param: ModelField | None, + model_name_map: dict[TypeModelOrEnum, str], + field_mapping: dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue], +) -> OpenAPIResponseContentSchema: + """Build the response schema for a return parameter.""" + if param is None: + return {} + + from aws_lambda_powertools.event_handler.openapi.compat import get_schema_from_model_field + + return_schema = get_schema_from_model_field( + field=param, + model_name_map=model_name_map, + field_mapping=field_mapping, + ) + + return {"schema": return_schema} + + +def _is_pydantic_model_param(field_info: Param) -> bool: + """Check if the field info represents a Pydantic model parameter.""" + from pydantic import BaseModel + + from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass + + return lenient_issubclass(field_info.annotation, BaseModel) + + +def _expand_pydantic_model_parameters(field_info: Param) -> list[dict[str, Any]]: + """Expand a Pydantic model into individual OpenAPI parameters.""" + from pydantic import BaseModel + + model_class = cast(type[BaseModel], field_info.annotation) + parameters: list[dict[str, Any]] = [] + + for field_name, field_def in model_class.model_fields.items(): + param_name = field_def.alias or field_name + individual_param = _create_pydantic_field_parameter( + param_name=param_name, + field_def=field_def, + param_location=field_info.in_.value, + ) + parameters.append(individual_param) + + return parameters + + +def _create_pydantic_field_parameter( + param_name: str, + field_def: Any, + param_location: str, +) -> dict[str, Any]: + """Create an OpenAPI parameter from a Pydantic field definition.""" + individual_param: dict[str, Any] = { + "name": param_name, + "in": param_location, + "required": field_def.is_required() if hasattr(field_def, "is_required") else field_def.default is ..., + "schema": _get_basic_type_schema(field_def.annotation or type(None)), + } + + if field_def.description: + individual_param["description"] = field_def.description + + return individual_param + + +def _create_regular_parameter( + param: ModelField, + model_name_map: dict[TypeModelOrEnum, str], + field_mapping: dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue], +) -> dict[str, Any]: + """Create an OpenAPI parameter from a regular ModelField.""" + from aws_lambda_powertools.event_handler.openapi.compat import get_schema_from_model_field + from aws_lambda_powertools.event_handler.openapi.params import Param + + field_info = cast(Param, param.field_info) + param_schema = get_schema_from_model_field( + field=param, + model_name_map=model_name_map, + field_mapping=field_mapping, + ) + + parameter: dict[str, Any] = { + "name": param.alias, + "in": field_info.in_.value, + "required": param.required, + "schema": param_schema, + } + + if field_info.description: + parameter["description"] = field_info.description + if field_info.openapi_examples: + parameter["examples"] = field_info.openapi_examples + if field_info.deprecated: + parameter["deprecated"] = field_info.deprecated + + return parameter + + +def _get_basic_type_schema(param_type: type) -> dict[str, str]: + """Get basic OpenAPI schema for simple types.""" + type_map: dict[type, str] = {bool: "boolean", int: "integer", float: "number"} + try: + for base_type, schema_type in type_map.items(): + if issubclass(param_type, base_type): + return {"type": schema_type} + return {"type": "string"} + except TypeError: + return {"type": "string"} + + +def _apply_optional_fields( + operation: dict[str, Any], + *, + security: list[dict[str, list[str]]] | None, + openapi_extensions: dict[str, Any] | None, +) -> None: + """Apply optional security and extension fields to the operation.""" + if security: + operation["security"] = security + if openapi_extensions: + operation.update(openapi_extensions) + + +def _apply_request_body( + operation: dict[str, Any], + *, + method: str, + body_field: ModelField | None, + model_name_map: dict[TypeModelOrEnum, str], + field_mapping: dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue], +) -> None: + """Build and apply request body to operation if applicable.""" + if method.upper() not in METHODS_WITH_BODY: + return + + request_body_oai = _build_request_body( + body_field=body_field, + model_name_map=model_name_map, + field_mapping=field_mapping, + ) + if request_body_oai: + operation["requestBody"] = request_body_oai + + +def _add_validation_responses( + operation_responses: dict[int, OpenAPIResponse], + *, + enable_validation: bool, +) -> None: + """Add 422 validation error response if validation is enabled.""" + if not enable_validation: + return + + operation_responses[422] = { + "description": "Validation Error", + "content": { + DEFAULT_CONTENT_TYPE: {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}HTTPValidationError"}}, + }, + } + + +def _add_response_validation_error( + operation_responses: dict[int, OpenAPIResponse], + definitions: dict[str, Any], + *, + custom_response_validation_http_code: HTTPStatus | None, +) -> None: + """Add response validation error if a custom HTTP code is configured.""" + if not custom_response_validation_http_code: + return + + http_code = custom_response_validation_http_code.value + operation_responses[http_code] = { + "description": "Response Validation Error", + "content": { + DEFAULT_CONTENT_TYPE: {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"}}, + }, + } + definitions["ResponseValidationError"] = response_validation_error_response_definition + + +def _deduplicate_parameters(parameters: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Deduplicate parameters, giving priority to required ones.""" + all_parameters = {(param["in"], param["name"]): param for param in parameters} + required_parameters = {(param["in"], param["name"]): param for param in parameters if param.get("required")} + all_parameters.update(required_parameters) + return list(all_parameters.values()) + + +def _add_validation_error_definitions(definitions: dict[str, Any]) -> None: + """Add standard validation error schema definitions if not already present.""" + if "ValidationError" not in definitions: + definitions["ValidationError"] = validation_error_definition + definitions["HTTPValidationError"] = validation_error_response_definition + + +def _warn_duplicate_operation_id( + operation_id: str, + operation_ids: set[str], + func_name: str, + func_file: str | None, +) -> None: + """Warn if an operationId has already been used.""" + if operation_id not in operation_ids: + return + + message = f"Duplicate Operation ID {operation_id} for function {func_name}" + if func_file: + message += f" in {func_file}" + warnings.warn(message, stacklevel=1) + + +def _build_media_content( + body_schema: dict[str, Any], + openapi_examples: dict[str, Any] | None, +) -> dict[str, Any]: + """Build the media content dict for a request body.""" + content: dict[str, Any] = {"schema": body_schema} + if openapi_examples: + content["examples"] = openapi_examples + return content + + +def _build_custom_response( + *, + response: OpenAPIResponse, + dependant: Dependant, + model_name_map: dict[TypeModelOrEnum, str], + field_mapping: dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue], +) -> OpenAPIResponse: + """Build a single custom response, resolving model references in content.""" + if "content" not in response: + response["content"] = { + DEFAULT_CONTENT_TYPE: _build_return_schema( + param=dependant.return_param, + model_name_map=model_name_map, + field_mapping=field_mapping, + ), + } + return response + + for content_type, payload in response["content"].items(): + response["content"][content_type] = _resolve_response_payload( + payload=payload, + dependant=dependant, + model_name_map=model_name_map, + field_mapping=field_mapping, + ) + + return response + + +def _resolve_response_payload( + *, + payload: OpenAPIResponseContentSchema | OpenAPIResponseContentModel, + dependant: Dependant, + model_name_map: dict[TypeModelOrEnum, str], + field_mapping: dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue], +) -> OpenAPIResponseContentSchema: + """Resolve a single response content payload, replacing model refs with schemas.""" + if "model" not in payload: + return cast(OpenAPIResponseContentSchema, payload) + + model_payload_typed = cast(OpenAPIResponseContentModel, payload) + return_field = next( + filter( + lambda model: model.type_ is model_payload_typed["model"], + dependant.response_extra_models, + ), + ) + if not return_field: + raise AssertionError("Model declared in custom responses was not found") + + model_payload = _build_return_schema( + param=return_field, + model_name_map=model_name_map, + field_mapping=field_mapping, + ) + + new_payload: OpenAPIResponseContentSchema = {} + for key, value in payload.items(): + if key != "model": + new_payload[key] = value # type: ignore[literal-required] + new_payload.update(model_payload) + return new_payload