diff --git a/doc/changelog.rst b/doc/changelog.rst index 4ba907a..645686f 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -1,6 +1,23 @@ Changelog ========= +[Unreleased] +------------ + +Performance +^^^^^^^^^^^ +- Cached commonly used metadata of fields to :attr:`~scim2_models.BaseModel.__scim_info__`. +- Collapsed all scim context validators in :class:`~scim2_models.BaseModel` to one model validator. +- Collapsed serialization to one model serializer in :class:`~scim2_models.BaseModel`. +- Moved ``model_dump`` and ``model_dump_json`` to :class:`~scim2_models.BaseModel`. +- Cached ``_normalize_attribute_name``. +- Simplified ``normalize_attribute_names`` + +Fixed +^^^^^ +- Check recursively extensions' replace constraints. +- The result of ``model_dump`` does not differ from native pydantic's dump if ``scim_ctx`` is set to ``None``. + [0.6.12] - 2026-04-13 --------------------- diff --git a/scim2_models/attributes.py b/scim2_models/attributes.py index 3764d12..ea59321 100644 --- a/scim2_models/attributes.py +++ b/scim2_models/attributes.py @@ -1,6 +1,7 @@ from inspect import isclass from typing import Annotated from typing import Any +from typing import ClassVar from typing import get_origin from pydantic import Field @@ -15,6 +16,8 @@ class ComplexAttribute(BaseModel): """A complex attribute as defined in :rfc:`RFC7643 §2.3.8 <7643#section-2.3.8>`.""" + __is_complex_attribute__: ClassVar[bool] = True + _attribute_urn: str | None = None def get_attribute_urn(self, field_name: str) -> str: diff --git a/scim2_models/base.py b/scim2_models/base.py index d427673..223ac25 100644 --- a/scim2_models/base.py +++ b/scim2_models/base.py @@ -1,20 +1,21 @@ import warnings from inspect import isclass +from typing import TYPE_CHECKING from typing import Any +from typing import ClassVar +from typing import NamedTuple from typing import Optional +from typing import cast from typing import get_args from typing import get_origin from pydantic import AliasGenerator from pydantic import BaseModel as PydanticBaseModel from pydantic import ConfigDict -from pydantic import FieldSerializationInfo from pydantic import SerializationInfo from pydantic import SerializerFunctionWrapHandler from pydantic import ValidationInfo from pydantic import ValidatorFunctionWrapHandler -from pydantic import field_serializer -from pydantic import field_validator from pydantic import model_serializer from pydantic import model_validator from pydantic_core import PydanticCustomError @@ -26,10 +27,12 @@ from scim2_models.context import Context from scim2_models.exceptions import MutabilityException from scim2_models.utils import UNION_TYPES -from scim2_models.utils import _find_field_name from scim2_models.utils import _normalize_attribute_name from scim2_models.utils import _to_camel +if TYPE_CHECKING: + from scim2_models.path import Path + def _short_attr_path(urn: str) -> str: """Extract the short attribute path from a full URN. @@ -98,6 +101,25 @@ def _is_attribute_requested(requested_attrs: list[str], current_urn: str) -> boo return any(_attr_matches(req, current_urn) for req in requested_attrs) +class _SCIMClassInfo(NamedTuple): + """SCIM metadata for BaseModel.""" + + alias_to_field: dict[str, str] = {} + """Alias -> Python field name. + + Holds both validation and serialization aliases. + """ + + attribute_urns: dict[str, str] = {} + """Python field name -> fully resolved SCIM attribute URN.""" + + complex_fields: frozenset[str] = frozenset() + """Field names whose root type is a ``ComplexAttribute`` subclass.""" + + extensions: frozenset[str] = frozenset() + """Field names whose root type is a ``Extension`` subclass.""" + + class BaseModel(PydanticBaseModel): """Base Model for everything.""" @@ -112,6 +134,9 @@ class BaseModel(PydanticBaseModel): extra="forbid", ) + __scim_info__: ClassVar[_SCIMClassInfo] = _SCIMClassInfo() + """Cached model metadata""" + @classmethod def get_field_annotation(cls, field_name: str, annotation_type: type) -> Any: """Return the annotation of type 'annotation_type' of the field 'field_name'. @@ -226,46 +251,61 @@ def get_field_multiplicity(cls, attribute_name: str) -> bool: origin = get_origin(attribute_type) return isinstance(origin, type) and issubclass(origin, list) - @field_validator("*") @classmethod - def check_request_attributes_mutability( - cls, value: Any, info: ValidationInfo - ) -> Any: - """Check and fix that the field mutability is expected according to the requests validation context, as defined in :rfc:`RFC7643 §7 <7643#section-7>`.""" - if ( - not info.context - or not info.field_name - or not info.context.get("scim") - or not Context.is_request(info.context["scim"]) - ): - return value - - context = info.context.get("scim") - mutability = cls.get_field_annotation(info.field_name, Mutability) - exc = PydanticCustomError( - "mutability_error", - "Field '{field_name}' has mutability '{field_mutability}' but this in not valid in {context} context", - { - "field_name": info.field_name, - "field_mutability": mutability, - "context": context.name.lower().replace("_", " "), - }, - ) + def __pydantic_on_complete__(cls) -> None: + """Build the per-class SCIM metadata table on ``cls.__scim_info__``. - if ( - context in (Context.RESOURCE_QUERY_REQUEST, Context.SEARCH_REQUEST) - and mutability == Mutability.write_only - ): - raise exc + Fires after pydantic resolves field types (re-fires after ``model_rebuild``). Idempotent. + """ + if not cls.model_fields: + return - if ( - context - in (Context.RESOURCE_CREATION_REQUEST, Context.RESOURCE_REPLACEMENT_REQUEST) - and mutability == Mutability.read_only - ): - return None + alias_to_field: dict[str, str] = {} + attribute_urns: dict[str, str] = {} + complex_fields: set[str] = set() + extensions: set[str] = set() - return value + main_schema = getattr(cls, "__schema__", None) + extension_cls: type | None = None + if main_schema is not None: + from scim2_models.resources.resource import Extension + + extension_cls = Extension + + for field_name, field in cls.model_fields.items(): + # Alias -> field name mapping + serialization_alias = field.serialization_alias or field_name + alias_to_field[serialization_alias] = field_name + alias_to_field[cast(str, field.validation_alias)] = field_name + + root_type = cls.get_field_root_type(field_name) + + # Is complex field + if root_type is not None and getattr( + root_type, "__is_complex_attribute__", False + ): + complex_fields.add(field_name) + + # Is extension + if ( + extension_cls is not None + and isclass(root_type) + and issubclass(root_type, extension_cls) + ): + extensions.add(field_name) + + # Attribute URNs + if main_schema is not None and field_name not in extensions: + attribute_urns[field_name] = f"{main_schema}:{serialization_alias}" + else: + attribute_urns[field_name] = serialization_alias + + cls.__scim_info__ = _SCIMClassInfo( + alias_to_field=alias_to_field, + attribute_urns=attribute_urns, + complex_fields=frozenset(complex_fields), + extensions=frozenset(extensions), + ) @model_validator(mode="wrap") @classmethod @@ -278,189 +318,143 @@ def normalize_attribute_names( names should be case-insensitive. Any attribute name is transformed in lowercase so any case is handled the same way. """ + if isinstance(value, dict): + value = {_normalize_attribute_name(k): v for k, v in value.items()} + return cast(Self, handler(value)) - def normalize_dict_keys( - input_dict: dict[str, Any], model_class: type["BaseModel"] - ) -> dict[str, Any]: - """Normalize dictionary keys, preserving case for Any fields.""" - result = {} + @model_validator(mode="after") + def enforce_scim_context(self, info: ValidationInfo) -> Self: + scim_context = info.context.get("scim") if info.context else None + if not scim_context or scim_context == Context.DEFAULT: + return self - for key, val in input_dict.items(): - field_name = _find_field_name(model_class, key) - field_type = ( - model_class.get_field_root_type(field_name) if field_name else None - ) + from scim2_models.resources.resource import Resource - # Don't normalize keys for attributes typed with Any - # This way, agnostic dicts such as PatchOp.operations.value - # are preserved - if field_name and field_type == Any: - result[key] = normalize_value(val) - else: - result[_normalize_attribute_name(key)] = normalize_value( - val, field_type - ) + is_create_or_replace = scim_context in ( + Context.RESOURCE_CREATION_REQUEST, + Context.RESOURCE_REPLACEMENT_REQUEST, + ) + original = info.context.get("original") if info.context else None + fields_set = self.model_fields_set - return result + for field_name in self.__class__.model_fields: + value = getattr(self, field_name) - def normalize_value( - val: Any, model_class: type["BaseModel"] | None = None - ) -> Any: - """Normalize input value based on model class.""" - if not isinstance(val, dict): - return val + if Context.is_request(scim_context): + if field_name in fields_set: + self._check_mutability(field_name, scim_context) + if is_create_or_replace: + self._check_necessity(field_name, value) + else: + # Must be response + self._check_returnability(field_name, value) - # If no model_class, preserve original keys - if not model_class: - return {k: normalize_value(v) for k, v in val.items()} + if self.get_field_multiplicity(field_name) and value is not None: + self._check_primary_uniqueness(field_name, value) - return normalize_dict_keys(val, model_class) + # DEPRECATED: Remove when original is not used in validation + if ( + scim_context == Context.RESOURCE_REPLACEMENT_REQUEST + and original is not None + and issubclass(type(self), Resource) + ): + self._check_replacement_mutability(original) - normalized_value = normalize_value(value, cls) - obj = handler(normalized_value) - assert isinstance(obj, cls) - return obj + return self - @model_validator(mode="wrap") - @classmethod - def check_response_attributes_returnability( - cls, value: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo - ) -> Self: - """Check that the fields returnability is expected according to the responses validation context, as defined in :rfc:`RFC7643 §7 <7643#section-7>`.""" - obj = handler(value) - assert isinstance(obj, cls) + def _check_mutability(self, field_name: str, scim_context: Context) -> None: + """Check and fix that the field mutability is expected according to the requests validation context, as defined in :rfc:`RFC7643 §7 <7643#section-7>`.""" + mutability = self.__class__.get_field_annotation(field_name, Mutability) if ( - not info.context - or not info.context.get("scim") - or not Context.is_response(info.context["scim"]) + scim_context in (Context.RESOURCE_QUERY_REQUEST, Context.SEARCH_REQUEST) + and mutability == Mutability.write_only ): - return obj - - for field_name in cls.model_fields: - returnability = cls.get_field_annotation(field_name, Returned) - - if returnability == Returned.always and getattr(obj, field_name) is None: - raise PydanticCustomError( - "returned_error", - "Field '{field_name}' has returnability 'always' but value is missing or null", - { - "field_name": field_name, - }, - ) - - if returnability == Returned.never and getattr(obj, field_name) is not None: - raise PydanticCustomError( - "returned_error", - "Field '{field_name}' has returnability 'never' but value is set", - { - "field_name": field_name, - }, - ) + raise PydanticCustomError( + "mutability_error", + "Field '{field_name}' has mutability '{field_mutability}' but this in not valid in {context} context", + { + "field_name": field_name, + "field_mutability": mutability, + "context": scim_context.name.lower().replace("_", " "), + }, + ) - return obj + elif ( + scim_context + in (Context.RESOURCE_CREATION_REQUEST, Context.RESOURCE_REPLACEMENT_REQUEST) + and mutability == Mutability.read_only + ): + # Avoid re-triggering this validation by using __dict__ + self.__dict__[field_name] = None - @model_validator(mode="wrap") - @classmethod - def check_response_attributes_necessity( - cls, value: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo - ) -> Self: + def _check_necessity(self, field_name: str, value: Any) -> None: """Check that the required attributes are present in creations and replacement requests.""" - obj = handler(value) - assert isinstance(obj, cls) + necessity = self.__class__.get_field_annotation(field_name, Required) + + if necessity == Required.true and value is None: + raise PydanticCustomError( + "required_error", + "Field '{field_name}' is required but value is missing or null", + { + "field_name": field_name, + }, + ) - if ( - not info.context - or not info.context.get("scim") - or info.context["scim"] - not in ( - Context.RESOURCE_CREATION_REQUEST, - Context.RESOURCE_REPLACEMENT_REQUEST, + def _check_returnability(self, field_name: str, value: Any) -> None: + """Check that the fields returnability is expected according to the responses validation context, as defined in :rfc:`RFC7643 §7 <7643#section-7>`.""" + returnability = self.__class__.get_field_annotation(field_name, Returned) + + if returnability == Returned.always and value is None: + raise PydanticCustomError( + "returned_error", + "Field '{field_name}' has returnability 'always' but value is missing or null", + { + "field_name": field_name, + }, ) - ): - return obj - - for field_name in cls.model_fields: - necessity = cls.get_field_annotation(field_name, Required) - - if necessity == Required.true and getattr(obj, field_name) is None: - raise PydanticCustomError( - "required_error", - "Field '{field_name}' is required but value is missing or null", - { - "field_name": field_name, - }, - ) - return obj + elif returnability == Returned.never and value is not None: + raise PydanticCustomError( + "returned_error", + "Field '{field_name}' has returnability 'never' but value is set", + { + "field_name": field_name, + }, + ) - @model_validator(mode="wrap") - @classmethod - def check_replacement_request_mutability( - cls, value: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo - ) -> Self: + def _check_replacement_mutability(self, original: "BaseModel") -> None: """Check if 'immutable' attributes have been mutated in replacement requests.""" - from scim2_models.resources.resource import Resource - - obj = handler(value) - assert isinstance(obj, cls) - - context = info.context.get("scim") if info.context else None - original = info.context.get("original") if info.context else None + try: + self._apply_replace_constraints(original) + except MutabilityException as exc: + raise exc.as_pydantic_error() from exc + + def _check_primary_uniqueness(self, field_name: str, value: Any) -> None: + """Validate that only one attribute can be marked as primary in multi-valued lists, per :rfc:`RFC7643 §2.4 <7643#section-2.4>`.""" + element_type = self.get_field_root_type(field_name) if ( - context == Context.RESOURCE_REPLACEMENT_REQUEST - and issubclass(cls, Resource) - and original is not None + element_type is None + or not isclass(element_type) + or not issubclass(element_type, PydanticBaseModel) + or "primary" not in element_type.model_fields ): - try: - obj._apply_replace_constraints(original) - except MutabilityException as exc: - raise exc.as_pydantic_error() from exc - return obj - - @model_validator(mode="after") - def check_primary_attribute_uniqueness(self, info: ValidationInfo) -> Self: - """Validate that only one attribute can be marked as primary in multi-valued lists. - - Per RFC 7643 Section 2.4: The primary attribute value 'true' MUST appear no more than once. - """ - scim_context = info.context.get("scim") if info.context else None - if not scim_context or scim_context == Context.DEFAULT: - return self - - for field_name in self.__class__.model_fields: - if not self.get_field_multiplicity(field_name): - continue + return - field_value = getattr(self, field_name) - if field_value is None: - continue - - element_type = self.get_field_root_type(field_name) - if ( - element_type is None - or not isclass(element_type) - or not issubclass(element_type, PydanticBaseModel) - or "primary" not in element_type.model_fields - ): - continue + primary_count = sum( + 1 for item in value if getattr(item, "primary", None) is True + ) - primary_count = sum( - 1 for item in field_value if getattr(item, "primary", None) is True + if primary_count > 1: + raise PydanticCustomError( + "primary_uniqueness_error", + "Field '{field_name}' has {count} items marked as primary, but only one is allowed per RFC 7643", + { + "field_name": field_name, + "count": primary_count, + }, ) - if primary_count > 1: - raise PydanticCustomError( - "primary_uniqueness_error", - "Field '{field_name}' has {count} items marked as primary, but only one is allowed per RFC 7643", - { - "field_name": field_name, - "count": primary_count, - }, - ) - - return self - def _apply_replace_constraints(self, original: Self) -> None: """Enforce RFC 7644 §3.5.1 replace (PUT) semantics. @@ -471,154 +465,165 @@ def _apply_replace_constraints(self, original: Self) -> None: Recursively applies to nested single-valued complex attributes. """ - from .attributes import is_complex_attribute - for field_name in type(self).model_fields: mutability = type(self).get_field_annotation(field_name, Mutability) original_val = getattr(original, field_name) if mutability == Mutability.read_only: # RFC 7644 §3.5.1: "readOnly" values provided SHALL be ignored. - setattr(self, field_name, original_val) + self.__dict__[field_name] = original_val elif mutability == Mutability.immutable: self_val = getattr(self, field_name) if self_val is None and original_val is not None: # RFC 7643 §7: "SHALL NOT be updated" — omitting an # immutable field is not a request to clear it. - setattr(self, field_name, original_val) + self.__dict__[field_name] = original_val elif self_val != original_val: # RFC 7644 §3.5.1: input values MUST match. raise MutabilityException( attribute=field_name, mutability="immutable" ) - attr_type = type(self).get_field_root_type(field_name) - if ( - attr_type - and is_complex_attribute(attr_type) - and not type(self).get_field_multiplicity(field_name) - ): - original_sub = getattr(original, field_name) - replacement_sub = getattr(self, field_name) + complex_and_extensions = self.__scim_info__.complex_fields.union( + self.__scim_info__.extensions + ) + for complex_attr in complex_and_extensions: + if not type(self).get_field_multiplicity(complex_attr): + original_sub = getattr(original, complex_attr) + replacement_sub = getattr(self, complex_attr) if original_sub is not None and replacement_sub is not None: replacement_sub._apply_replace_constraints(original_sub) - def _set_complex_attribute_urns(self) -> None: - """Navigate through attributes and sub-attributes of type ComplexAttribute, and mark them with a '_attribute_urn' attribute. + def get_attribute_urn(self, field_name: str) -> str: + """Build the full URN of the attribute. - '_attribute_urn' will later be used by 'get_attribute_urn'. + See :rfc:`RFC7644 §3.10 <7644#section-3.10>`. """ - from .attributes import ComplexAttribute - from .attributes import is_complex_attribute + return self.__scim_info__.attribute_urns[field_name] - if isinstance(self, ComplexAttribute): - main_schema = self._attribute_urn - separator = "." - else: - main_schema = getattr(self.__class__, "__schema__", None) - if main_schema is None: - return - separator = ":" + def _set_complex_attribute_urns(self) -> None: + """Mark each ``ComplexAttribute`` child with its ``_attribute_urn``. - for field_name in self.__class__.model_fields: - attr_type = self.get_field_root_type(field_name) - if not attr_type or not is_complex_attribute(attr_type): + ``_attribute_urn`` is later read by :meth:`get_attribute_urn`. + """ + cls = self.__class__ + info = cls.__scim_info__ + complex_fields = info.complex_fields + + for field_name in complex_fields: + attr_value = getattr(self, field_name) + if not attr_value: continue - alias = ( - self.__class__.model_fields[field_name].serialization_alias - or field_name - ) - schema = f"{main_schema}{separator}{alias}" + schema = info.attribute_urns[field_name] - if attr_value := getattr(self, field_name): - if isinstance(attr_value, list): - for item in attr_value: - item._attribute_urn = schema - else: - attr_value._attribute_urn = schema + if isinstance(attr_value, list): + for item in attr_value: + item._attribute_urn = schema + else: + attr_value._attribute_urn = schema - @field_serializer("*", mode="wrap") + @model_serializer(mode="wrap") def scim_serializer( - self, - value: Any, - handler: SerializerFunctionWrapHandler, - info: FieldSerializationInfo, - ) -> Any: + self, handler: SerializerFunctionWrapHandler, info: SerializationInfo + ) -> dict[str, Any]: """Serialize the fields according to mutability indications passed in the serialization context.""" - value = handler(value) scim_ctx = info.context.get("scim") if info.context else None + is_response = Context.is_response(scim_ctx) if scim_ctx else False - if scim_ctx and Context.is_request(scim_ctx): - value = self._scim_request_serializer(value, info) + if is_response: + # Complex attribute urns are only used in responses + self._set_complex_attribute_urns() - if scim_ctx and Context.is_response(scim_ctx): - value = self._scim_response_serializer(value, info) + serialized: dict[str, Any] = handler(self) - return value - - def _scim_request_serializer(self, value: Any, info: FieldSerializationInfo) -> Any: - """Serialize the fields according to mutability indications passed in the serialization context.""" - mutability = self.get_field_annotation(info.field_name, Mutability) - scim_ctx = info.context.get("scim") if info.context else None - - if ( - scim_ctx - in (Context.RESOURCE_CREATION_REQUEST, Context.RESOURCE_REPLACEMENT_REQUEST) - and mutability == Mutability.read_only - ): - return None + if not scim_ctx: + return serialized - if ( - scim_ctx - in ( - Context.RESOURCE_QUERY_REQUEST, - Context.SEARCH_REQUEST, + # Delete empty extensions + for extension_field in self.__scim_info__.extensions: + key = ( + self.__scim_info__.attribute_urns[extension_field] + if info.by_alias + else extension_field ) - and mutability == Mutability.write_only - ): - return None - - return value - - def _scim_response_serializer( - self, value: Any, info: FieldSerializationInfo - ) -> Any: - """Serialize the fields according to returnability indications passed in the serialization context.""" - returnability = self.get_field_annotation(info.field_name, Returned) - attribute_urn = self.get_attribute_urn(info.field_name) - included_attrs = info.context.get("scim_attributes", []) if info.context else [] - excluded_attrs = ( - info.context.get("scim_excluded_attributes", []) if info.context else [] - ) + if key in serialized and serialized[key] is None: + del serialized[key] + + # Serialize according to given context + if scim_ctx != Context.DEFAULT: + if is_response: + included_attrs = ( + info.context.get("scim_attributes", []) if info.context else [] + ) + excluded_attrs = ( + info.context.get("scim_excluded_attributes", []) + if info.context + else [] + ) + self._scim_response_serializer( + serialized, included_attrs, excluded_attrs + ) + else: + # Must be request + self._scim_request_serializer(serialized, scim_ctx) - if returnability == Returned.never: - return None + return serialized - if returnability == Returned.default and ( - ( - included_attrs - and not _is_attribute_requested(included_attrs, attribute_urn) - ) - or _exact_attr_match(excluded_attrs, attribute_urn) - ): - return None + def _scim_request_serializer( + self, serialized: dict[str, Any], scim_ctx: Context + ) -> None: + """Serialize the fields according to mutability indications passed in the serialization context.""" + for alias in set(serialized): + field_name = self.__scim_info__.alias_to_field.get(alias, alias) + mutability = self.get_field_annotation(field_name, Mutability) - if returnability == Returned.request and not _exact_attr_match( - included_attrs, attribute_urn - ): - return None + if ( + scim_ctx + in ( + Context.RESOURCE_CREATION_REQUEST, + Context.RESOURCE_REPLACEMENT_REQUEST, + ) + and mutability == Mutability.read_only + ): + del serialized[alias] - return value + elif ( + scim_ctx + in ( + Context.RESOURCE_QUERY_REQUEST, + Context.SEARCH_REQUEST, + ) + and mutability == Mutability.write_only + ): + del serialized[alias] - @model_serializer(mode="wrap") - def model_serializer_exclude_none( - self, handler: SerializerFunctionWrapHandler, info: SerializationInfo - ) -> dict[str, Any]: - """Remove `None` values inserted by the :meth:`~scim2_models.base.BaseModel.scim_serializer`.""" - self._set_complex_attribute_urns() - result = handler(self) - return {key: value for key, value in result.items() if value is not None} + def _scim_response_serializer( + self, + serialized: dict[str, Any], + included_attrs: list[str], + excluded_attrs: list[str], + ) -> None: + """Serialize the fields according to returnability indications passed in the serialization context.""" + for alias in set(serialized): + field_name = self.__scim_info__.alias_to_field.get(alias, alias) + returnability = self.get_field_annotation(field_name, Returned) + attribute_urn = self.get_attribute_urn(field_name) + + if returnability == Returned.never: + del serialized[alias] + elif returnability == Returned.default and ( + ( + included_attrs + and not _is_attribute_requested(included_attrs, attribute_urn) + ) + or _exact_attr_match(excluded_attrs, attribute_urn) + ): + del serialized[alias] + elif returnability == Returned.request and not _exact_attr_match( + included_attrs, attribute_urn + ): + del serialized[alias] @classmethod def model_validate( @@ -654,19 +659,80 @@ def model_validate( return super().model_validate(*args, **kwargs) - def get_attribute_urn(self, field_name: str) -> str: - """Build the full URN of the attribute. + def _prepare_model_dump( + self, + scim_ctx: Context | None = Context.DEFAULT, + attributes: list["str | Path[Any]"] | None = None, + excluded_attributes: list["str | Path[Any]"] | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + kwargs.setdefault("context", {}).setdefault("scim", scim_ctx) - See :rfc:`RFC7644 §3.10 <7644#section-3.10>`. + if scim_ctx: + kwargs.setdefault("exclude_none", True) + kwargs.setdefault("by_alias", True) + + if attributes: + kwargs["context"]["scim_attributes"] = [str(a) for a in attributes] + if excluded_attributes: + kwargs["context"]["scim_excluded_attributes"] = [ + str(a) for a in excluded_attributes + ] + + return kwargs + + def model_dump( + self, + *args: Any, + scim_ctx: Context | None = Context.DEFAULT, + attributes: list["str | Path[Any]"] | None = None, + excluded_attributes: list["str | Path[Any]"] | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + """Create a model representation that can be included in SCIM messages by using Pydantic :code:`BaseModel.model_dump`. + + :param scim_ctx: If a SCIM context is passed, some default values of + Pydantic :code:`BaseModel.model_dump` are tuned to generate valid SCIM + messages. Pass :data:`None` to get the default Pydantic behavior. + :param attributes: A multi-valued list of strings indicating the names of resource + attributes to return in the response, overriding the set of attributes that + would be returned by default. Invalid values are ignored. + :param excluded_attributes: A multi-valued list of strings indicating the names of resource + attributes to be removed from the default set of attributes to return. Invalid values are ignored. """ - from scim2_models.resources.resource import Extension - - main_schema = getattr(self.__class__, "__schema__", None) - field = self.__class__.model_fields[field_name] - alias = field.serialization_alias or field_name - field_type = self.get_field_root_type(field_name) - if isclass(field_type) and issubclass(field_type, Extension): - return alias - if main_schema is None: - return alias - return f"{main_schema}:{alias}" + dump_kwargs = self._prepare_model_dump( + scim_ctx, + attributes=attributes, + excluded_attributes=excluded_attributes, + **kwargs, + ) + if scim_ctx: + dump_kwargs.setdefault("mode", "json") + return super().model_dump(*args, **dump_kwargs) + + def model_dump_json( + self, + *args: Any, + scim_ctx: Context | None = Context.DEFAULT, + attributes: list["str | Path[Any]"] | None = None, + excluded_attributes: list["str | Path[Any]"] | None = None, + **kwargs: Any, + ) -> str: + """Create a JSON model representation that can be included in SCIM messages by using Pydantic :code:`BaseModel.model_dump_json`. + + :param scim_ctx: If a SCIM context is passed, some default values of + Pydantic :code:`BaseModel.model_dump` are tuned to generate valid SCIM + messages. Pass :data:`None` to get the default Pydantic behavior. + :param attributes: A multi-valued list of strings indicating the names of resource + attributes to return in the response, overriding the set of attributes that + would be returned by default. Invalid values are ignored. + :param excluded_attributes: A multi-valued list of strings indicating the names of resource + attributes to be removed from the default set of attributes to return. Invalid values are ignored. + """ + dump_kwargs = self._prepare_model_dump( + scim_ctx, + attributes=attributes, + excluded_attributes=excluded_attributes, + **kwargs, + ) + return super().model_dump_json(*args, **dump_kwargs) diff --git a/scim2_models/messages/message.py b/scim2_models/messages/message.py index 5b005d4..2b75665 100644 --- a/scim2_models/messages/message.py +++ b/scim2_models/messages/message.py @@ -1,5 +1,4 @@ from collections.abc import Callable -from typing import TYPE_CHECKING from typing import Annotated from typing import Any from typing import Union @@ -16,18 +15,12 @@ from ..scim_object import ScimObject from ..utils import UNION_TYPES -if TYPE_CHECKING: - from pydantic import FieldSerializationInfo - class Message(ScimObject): """SCIM protocol messages as defined by :rfc:`RFC7644 §3.1 <7644#section-3.1>`.""" - def _scim_response_serializer( - self, value: Any, info: "FieldSerializationInfo" - ) -> Any: + def _scim_response_serializer(self, *args: Any, **kwargs: Any) -> None: """Message fields are not subject to attribute filtering.""" - return value def _create_schema_discriminator( diff --git a/scim2_models/resources/resource.py b/scim2_models/resources/resource.py index e0c13b7..278bb86 100644 --- a/scim2_models/resources/resource.py +++ b/scim2_models/resources/resource.py @@ -112,7 +112,7 @@ def from_schema(cls, schema: "Schema") -> type["Extension"]: def _extension_serializer( value: Any, handler: SerializerFunctionWrapHandler, info: SerializationInfo -) -> dict[str, Any] | None: +) -> Any: """Exclude the Resource attributes from the extension dump. For instance, attributes 'meta', 'id' or 'schemas' should not be @@ -123,6 +123,10 @@ def _extension_serializer( partial_result = handler(value) + scim_context = info.context.get("scim") if info.context else None + if not scim_context: + return partial_result + result = { attr_name: value for attr_name, value in partial_result.items() diff --git a/scim2_models/scim_object.py b/scim2_models/scim_object.py index b48121e..434d734 100644 --- a/scim2_models/scim_object.py +++ b/scim2_models/scim_object.py @@ -1,7 +1,6 @@ """Base SCIM object classes with schema identification.""" import warnings -from typing import TYPE_CHECKING from typing import Annotated from typing import Any from typing import ClassVar @@ -14,13 +13,10 @@ from typing_extensions import Self from .annotations import Required +from .annotations import Returned from .base import BaseModel from .context import Context from .path import URN -from .path import Path - -if TYPE_CHECKING: - pass class ScimMetaclass(ModelMetaclass): @@ -65,7 +61,7 @@ def __new__( class ScimObject(BaseModel, metaclass=ScimMetaclass): __schema__: ClassVar[URN | None] = None - schemas: Annotated[list[str], Required.true] + schemas: Annotated[list[str], Required.true, Returned.always] """The "schemas" attribute is a REQUIRED attribute and is an array of Strings containing URIs that are used to indicate the namespaces of the SCIM schemas that define the attributes present in the current JSON @@ -102,81 +98,3 @@ def _validate_schemas_attribute( ) return obj - - def _prepare_model_dump( - self, - scim_ctx: Context | None = Context.DEFAULT, - attributes: list[str | Path[Any]] | None = None, - excluded_attributes: list[str | Path[Any]] | None = None, - **kwargs: Any, - ) -> dict[str, Any]: - kwargs.setdefault("context", {}).setdefault("scim", scim_ctx) - - if scim_ctx: - kwargs.setdefault("exclude_none", True) - kwargs.setdefault("by_alias", True) - - if attributes: - kwargs["context"]["scim_attributes"] = [str(a) for a in attributes] - if excluded_attributes: - kwargs["context"]["scim_excluded_attributes"] = [ - str(a) for a in excluded_attributes - ] - - return kwargs - - def model_dump( - self, - *args: Any, - scim_ctx: Context | None = Context.DEFAULT, - attributes: list[str | Path[Any]] | None = None, - excluded_attributes: list[str | Path[Any]] | None = None, - **kwargs: Any, - ) -> dict[str, Any]: - """Create a model representation that can be included in SCIM messages by using Pydantic :code:`BaseModel.model_dump`. - - :param scim_ctx: If a SCIM context is passed, some default values of - Pydantic :code:`BaseModel.model_dump` are tuned to generate valid SCIM - messages. Pass :data:`None` to get the default Pydantic behavior. - :param attributes: A multi-valued list of strings indicating the names of resource - attributes to return in the response, overriding the set of attributes that - would be returned by default. Invalid values are ignored. - :param excluded_attributes: A multi-valued list of strings indicating the names of resource - attributes to be removed from the default set of attributes to return. Invalid values are ignored. - """ - dump_kwargs = self._prepare_model_dump( - scim_ctx, - attributes=attributes, - excluded_attributes=excluded_attributes, - **kwargs, - ) - if scim_ctx: - dump_kwargs.setdefault("mode", "json") - return super(BaseModel, self).model_dump(*args, **dump_kwargs) - - def model_dump_json( - self, - *args: Any, - scim_ctx: Context | None = Context.DEFAULT, - attributes: list[str | Path[Any]] | None = None, - excluded_attributes: list[str | Path[Any]] | None = None, - **kwargs: Any, - ) -> str: - """Create a JSON model representation that can be included in SCIM messages by using Pydantic :code:`BaseModel.model_dump_json`. - - :param scim_ctx: If a SCIM context is passed, some default values of - Pydantic :code:`BaseModel.model_dump` are tuned to generate valid SCIM - messages. Pass :data:`None` to get the default Pydantic behavior. - :param attributes: A multi-valued list of strings indicating the names of resource - attributes to return in the response, overriding the set of attributes that - would be returned by default. Invalid values are ignored. - :param excluded_attributes: A multi-valued list of strings indicating the names of resource - attributes to be removed from the default set of attributes to return. Invalid values are ignored. - """ - dump_kwargs = self._prepare_model_dump( - scim_ctx, - attributes=attributes, - excluded_attributes=excluded_attributes, - **kwargs, - ) - return super(BaseModel, self).model_dump_json(*args, **dump_kwargs) diff --git a/scim2_models/utils.py b/scim2_models/utils.py index 4bddfba..70b2ace 100644 --- a/scim2_models/utils.py +++ b/scim2_models/utils.py @@ -1,4 +1,5 @@ import re +from functools import lru_cache from typing import TYPE_CHECKING from typing import Union @@ -36,6 +37,7 @@ def _to_camel(string: str) -> str: return camel +@lru_cache(maxsize=256) def _normalize_attribute_name(attribute_name: str) -> str: """Remove all non-alphabetical characters and lowerise a string. diff --git a/tests/test_model_serialization.py b/tests/test_model_serialization.py index 5f41aa1..44edbcd 100644 --- a/tests/test_model_serialization.py +++ b/tests/test_model_serialization.py @@ -4,10 +4,10 @@ from scim2_models import URN from scim2_models.annotations import Mutability -from scim2_models.annotations import Required from scim2_models.annotations import Returned from scim2_models.attributes import ComplexAttribute from scim2_models.context import Context +from scim2_models.resources.resource import Extension from scim2_models.resources.resource import Resource @@ -30,7 +30,16 @@ class SupRetResource(Resource): class MutResource(Resource): - schemas: Annotated[list[str], Required.true] = ["org:example:MutResource"] + __schema__ = URN("urn:org:example:MutResource") + + read_only: Annotated[str | None, Mutability.read_only] = None + read_write: Annotated[str | None, Mutability.read_write] = None + immutable: Annotated[str | None, Mutability.immutable] = None + write_only: Annotated[str | None, Mutability.write_only] = None + + +class MutExtension(Extension): + __schema__ = URN("urn:org:extensions:MutExtension") read_only: Annotated[str | None, Mutability.read_only] = None read_write: Annotated[str | None, Mutability.read_write] = None @@ -66,17 +75,48 @@ def mut_resource(): ) +@pytest.fixture +def mut_resource_extension(): + resource = MutResource[MutExtension]( + id="id", + read_only="x", + read_write="x", + immutable="x", + write_only="x", + ) + resource[MutExtension] = MutExtension( + read_only="y", + read_write="y", + immutable="y", + write_only="y", + ) + return resource + + +@pytest.fixture +def mut_resource_extension_empty(): + resource = MutResource[MutExtension]( + id="id", + read_only="x", + read_write="x", + immutable="x", + write_only="x", + ) + resource[MutExtension] = MutExtension() + return resource + + def test_model_dump_json(mut_resource): assert ( mut_resource.model_dump_json() - == '{"schemas":["org:example:MutResource"],"id":"id","readOnly":"x","readWrite":"x","immutable":"x","writeOnly":"x"}' + == '{"schemas":["urn:org:example:MutResource"],"id":"id","readOnly":"x","readWrite":"x","immutable":"x","writeOnly":"x"}' ) def test_dump_default(mut_resource): """By default, everything is dumped.""" assert mut_resource.model_dump() == { - "schemas": ["org:example:MutResource"], + "schemas": ["urn:org:example:MutResource"], "id": "id", "readOnly": "x", "readWrite": "x", @@ -85,7 +125,7 @@ def test_dump_default(mut_resource): } assert mut_resource.model_dump(scim_ctx=Context.DEFAULT) == { - "schemas": ["org:example:MutResource"], + "schemas": ["urn:org:example:MutResource"], "id": "id", "readOnly": "x", "readWrite": "x", @@ -94,12 +134,177 @@ def test_dump_default(mut_resource): } assert mut_resource.model_dump(scim_ctx=None) == { - "schemas": ["org:example:MutResource"], + "schemas": ["urn:org:example:MutResource"], + "id": "id", + "external_id": None, + "meta": None, + "read_only": "x", + "read_write": "x", + "immutable": "x", + "write_only": "x", + } + + assert mut_resource.model_dump(scim_ctx=None, exclude_none=True) == { + "schemas": ["urn:org:example:MutResource"], + "id": "id", + "read_only": "x", + "read_write": "x", + "immutable": "x", + "write_only": "x", + } + + +def test_dump_extension(mut_resource_extension): + """Test dumps with extension.""" + assert mut_resource_extension.model_dump() == { + "schemas": ["urn:org:example:MutResource", "urn:org:extensions:MutExtension"], + "id": "id", + "readOnly": "x", + "readWrite": "x", + "immutable": "x", + "writeOnly": "x", + "urn:org:extensions:MutExtension": { + "readOnly": "y", + "readWrite": "y", + "immutable": "y", + "writeOnly": "y", + }, + } + + assert mut_resource_extension.model_dump(scim_ctx=None) == { + "schemas": ["urn:org:example:MutResource", "urn:org:extensions:MutExtension"], + "id": "id", + "external_id": None, + "meta": None, + "read_only": "x", + "read_write": "x", + "immutable": "x", + "write_only": "x", + "MutExtension": { + "schemas": ["urn:org:extensions:MutExtension"], + "read_only": "y", + "read_write": "y", + "immutable": "y", + "write_only": "y", + }, + } + + assert mut_resource_extension.model_dump(scim_ctx=None, exclude_none=True) == { + "schemas": ["urn:org:example:MutResource", "urn:org:extensions:MutExtension"], "id": "id", "read_only": "x", "read_write": "x", "immutable": "x", "write_only": "x", + "MutExtension": { + "schemas": ["urn:org:extensions:MutExtension"], + "read_only": "y", + "read_write": "y", + "immutable": "y", + "write_only": "y", + }, + } + + assert mut_resource_extension.model_dump(by_alias=False) == { + "schemas": ["urn:org:example:MutResource", "urn:org:extensions:MutExtension"], + "id": "id", + "read_only": "x", + "read_write": "x", + "immutable": "x", + "write_only": "x", + "MutExtension": { + "read_only": "y", + "read_write": "y", + "immutable": "y", + "write_only": "y", + }, + } + + assert mut_resource_extension.model_dump(scim_ctx=None, by_alias=True) == { + "schemas": ["urn:org:example:MutResource", "urn:org:extensions:MutExtension"], + "id": "id", + "externalId": None, + "meta": None, + "readOnly": "x", + "readWrite": "x", + "immutable": "x", + "writeOnly": "x", + "urn:org:extensions:MutExtension": { + "schemas": ["urn:org:extensions:MutExtension"], + "readOnly": "y", + "readWrite": "y", + "immutable": "y", + "writeOnly": "y", + }, + } + + +def test_dump_empty_extension(mut_resource_extension_empty): + """Test dumps with empty extension.""" + assert mut_resource_extension_empty.model_dump() == { + "schemas": ["urn:org:example:MutResource", "urn:org:extensions:MutExtension"], + "id": "id", + "readOnly": "x", + "readWrite": "x", + "immutable": "x", + "writeOnly": "x", + } + + assert mut_resource_extension_empty.model_dump(scim_ctx=None) == { + "schemas": ["urn:org:example:MutResource", "urn:org:extensions:MutExtension"], + "id": "id", + "external_id": None, + "meta": None, + "read_only": "x", + "read_write": "x", + "immutable": "x", + "write_only": "x", + "MutExtension": { + "schemas": ["urn:org:extensions:MutExtension"], + "read_only": None, + "read_write": None, + "immutable": None, + "write_only": None, + }, + } + + assert mut_resource_extension_empty.model_dump( + scim_ctx=None, exclude_none=True + ) == { + "schemas": ["urn:org:example:MutResource", "urn:org:extensions:MutExtension"], + "id": "id", + "read_only": "x", + "read_write": "x", + "immutable": "x", + "write_only": "x", + "MutExtension": {"schemas": ["urn:org:extensions:MutExtension"]}, + } + + assert mut_resource_extension_empty.model_dump(by_alias=False) == { + "schemas": ["urn:org:example:MutResource", "urn:org:extensions:MutExtension"], + "id": "id", + "read_only": "x", + "read_write": "x", + "immutable": "x", + "write_only": "x", + } + + assert mut_resource_extension_empty.model_dump(scim_ctx=None, by_alias=True) == { + "schemas": ["urn:org:example:MutResource", "urn:org:extensions:MutExtension"], + "id": "id", + "externalId": None, + "meta": None, + "readOnly": "x", + "readWrite": "x", + "immutable": "x", + "writeOnly": "x", + "urn:org:extensions:MutExtension": { + "schemas": ["urn:org:extensions:MutExtension"], + "readOnly": None, + "readWrite": None, + "immutable": None, + "writeOnly": None, + }, } @@ -113,7 +318,7 @@ def test_dump_creation_request(mut_resource): - Mutability.read_only are not dumped """ assert mut_resource.model_dump(scim_ctx=Context.RESOURCE_CREATION_REQUEST) == { - "schemas": ["org:example:MutResource"], + "schemas": ["urn:org:example:MutResource"], "readWrite": "x", "immutable": "x", "writeOnly": "x", @@ -130,7 +335,7 @@ def test_dump_query_request(mut_resource): - Mutability.read_only are dumped """ assert mut_resource.model_dump(scim_ctx=Context.RESOURCE_QUERY_REQUEST) == { - "schemas": ["org:example:MutResource"], + "schemas": ["urn:org:example:MutResource"], "id": "id", "readOnly": "x", "readWrite": "x", @@ -148,7 +353,7 @@ def test_dump_replacement_request(mut_resource): - Mutability.read_only are not dumped """ assert mut_resource.model_dump(scim_ctx=Context.RESOURCE_REPLACEMENT_REQUEST) == { - "schemas": ["org:example:MutResource"], + "schemas": ["urn:org:example:MutResource"], "readWrite": "x", "writeOnly": "x", "immutable": "x", @@ -165,7 +370,7 @@ def test_dump_search_request(mut_resource): - Mutability.read_only are dumped """ assert mut_resource.model_dump(scim_ctx=Context.RESOURCE_QUERY_REQUEST) == { - "schemas": ["org:example:MutResource"], + "schemas": ["urn:org:example:MutResource"], "id": "id", "readOnly": "x", "readWrite": "x", diff --git a/tests/test_model_validation.py b/tests/test_model_validation.py index 0c3d6f8..7a0350e 100644 --- a/tests/test_model_validation.py +++ b/tests/test_model_validation.py @@ -332,6 +332,49 @@ class Super(Resource): assert replacement.sub.read_write == "new" +def test_replace_detects_changed_immutable_in_extension(): + """Replace detects changes in immutable fields inside extensions.""" + from scim2_models import URN + from scim2_models import Extension + from scim2_models.exceptions import MutabilityException + + class MyExt(Extension): + __schema__ = URN("urn:example:extensions:2.0:MyExt") + immutable: Annotated[str | None, Mutability.immutable] = None + + class MyResource(Resource): + __schema__ = URN("urn:example:resources:2.0:MyResource") + + original = MyResource[MyExt]() + original[MyExt] = MyExt(immutable="x") + replacement = MyResource[MyExt]() + replacement[MyExt] = MyExt(immutable="y") + with pytest.raises(MutabilityException): + replacement.replace(original) + + +def test_replace_copies_read_only_in_extension(): + """Replace copies readOnly fields from original inside extensions.""" + from scim2_models import URN + from scim2_models import Extension + + class MyExt(Extension): + __schema__ = URN("urn:example:extensions:2.0:MyExt") + read_only: Annotated[str | None, Mutability.read_only] = None + read_write: Annotated[str | None, Mutability.read_write] = None + + class MyResource(Resource): + __schema__ = URN("urn:example:resources:2.0:MyResource") + + original = MyResource[MyExt]() + original[MyExt] = MyExt(read_only="server", read_write="old") + replacement = MyResource[MyExt]() + replacement[MyExt] = MyExt(read_only="client", read_write="new") + replacement.replace(original) + assert replacement[MyExt].read_only == "server" + assert replacement[MyExt].read_write == "new" + + def test_original_parameter_emits_deprecation_warning(): """Passing 'original' to model_validate emits a DeprecationWarning.""" original = MutResource(immutable="y") diff --git a/tests/test_reference.py b/tests/test_reference.py index a637bc6..ae07cd3 100644 --- a/tests/test_reference.py +++ b/tests/test_reference.py @@ -75,7 +75,7 @@ def test_reference_serialization(): model = ReferenceTestModel(uri_ref=ref) dumped = model.model_dump() - assert dumped["uri_ref"] == "https://example.com" + assert dumped["uriRef"] == "https://example.com" def test_reference_validation_error():