diff --git a/lagrange/utils/binary/protobuf/models.py b/lagrange/utils/binary/protobuf/models.py index 8bce841..4e02e29 100644 --- a/lagrange/utils/binary/protobuf/models.py +++ b/lagrange/utils/binary/protobuf/models.py @@ -1,15 +1,16 @@ import inspect +import importlib from types import GenericAlias -from typing import cast, Dict, List, Tuple, Type, TypeVar, Union, Generic, Any, Callable, Mapping, overload -from typing_extensions import Optional, Self, TypeAlias, dataclass_transform, get_origin, get_args +from typing import cast, Dict, List, Tuple, Type, TypeVar, Union, Generic, Any, Callable, Mapping, overload, ForwardRef +from typing_extensions import Optional, Self, TypeAlias, dataclass_transform from .coder import Proto, proto_decode, proto_encode -_ProtoTypes = Union[str, list, dict, bytes, int, float, bool, "ProtoStruct"] +_ProtoBasicTypes = Union[str, list, dict, bytes, int, float, bool] +_ProtoTypes = Union[_ProtoBasicTypes, "ProtoStruct"] T = TypeVar("T", str, list, dict, bytes, int, float, bool, "ProtoStruct") V = TypeVar("V") -NT: TypeAlias = Dict[int, Union[_ProtoTypes, "NT"]] NoneType = type(None) @@ -17,8 +18,8 @@ class ProtoField(Generic[T]): def __init__(self, tag: int, default: T): if tag <= 0: raise ValueError("Tag must be a positive integer") - self._tag = tag - self._default = default + self._tag: int = tag + self._default: T = default @property def tag(self) -> int: @@ -79,17 +80,27 @@ def proto_field( return ProtoField(tag, default) +NT: TypeAlias = Dict[int, Union[_ProtoTypes, "NT"]] +AMT: TypeAlias = Dict[str, Tuple[Type[_ProtoTypes], "ProtoField"]] +PS = TypeVar("PS", bound=ProtoField) +DAMT: Union[Type[list[ForwardRef]], ForwardRef] +DAMDT: TypeAlias = Dict[str, Union[Type[list[ForwardRef]], ForwardRef]] + + +# noinspection PyProtectedMember @dataclass_transform(kw_only_default=True, field_specifiers=(proto_field,)) class ProtoStruct: _anno_map: Dict[str, Tuple[Type[_ProtoTypes], ProtoField[Any]]] + _delay_anno_map: DAMDT _proto_debug: bool def __init__(self, *args, **kwargs): undefined_params: List[str] = [] - args = list(args) + args_list = list(args) + self._resolve_annotations(self) for name, (typ, field) in self._anno_map.items(): if args: - self._set_attr(name, typ, args.pop(0)) + self._set_attr(name, typ, args_list.pop(0)) elif name in kwargs: self._set_attr(name, typ, kwargs.pop(name)) else: @@ -98,13 +109,11 @@ def __init__(self, *args, **kwargs): else: undefined_params.append(name) if undefined_params: - raise AttributeError( - "Undefined parameters in '{}': {}".format(self, undefined_params) - ) + raise AttributeError(f"Undefined parameters in {self}: {undefined_params}") def __init_subclass__(cls, **kwargs): - cls._anno_map = cls._get_annotations() cls._proto_debug = kwargs.pop("debug") if "debug" in kwargs else False + cls._anno_map, cls._delay_anno_map = cls._get_annotations() super().__init_subclass__(**kwargs) def __repr__(self) -> str: @@ -119,16 +128,19 @@ def _set_attr(self, name: str, data_typ: Type[V], value: V) -> None: if isinstance(data_typ, GenericAlias): # force ignore pass elif not isinstance(value, data_typ) and value is not None: - raise TypeError( - "'{}' is not a instance of type '{}'".format(value, data_typ) - ) + raise TypeError("{value} is not a instance of type {data_typ}") setattr(self, name, value) @classmethod - def _get_annotations( - cls, - ) -> Dict[str, Tuple[Type[_ProtoTypes], "ProtoField"]]: # Name: (ReturnType, ProtoField) - annotations: Dict[str, Tuple[Type[_ProtoTypes], "ProtoField"]] = {} + def _handle_inner_generic(cls, inner: GenericAlias) -> GenericAlias: + if inner.__origin__ is list: + return GenericAlias(list, ForwardRef(inner.__args__[0])) + raise NotImplementedError(f"unknown inner generic type '{inner}'") + + @classmethod + def _get_annotations(cls) -> Tuple[AMT, DAMDT]: # Name: (ReturnType, ProtoField) + annotations: AMT = {} + delay_annotations: DAMDT = {} for obj in reversed(inspect.getmro(cls)): if obj in (ProtoStruct, object): # base object, ignore continue @@ -142,15 +154,34 @@ def _get_annotations( if not isinstance(field, ProtoField): raise TypeError("attribute '{name}' is not a ProtoField object") + _typ = typ + annotations[name] = (_typ, field) + if isinstance(typ, str): + delay_annotations[name] = ForwardRef(typ) if hasattr(typ, "__origin__"): - typ = typ.__origin__[typ.__args__[0]] - annotations[name] = (typ, field) - - return annotations + typ = cast(GenericAlias, typ) + _inner = typ.__args__[0] + _typ = typ.__origin__[typ.__args__[0]] + annotations[name] = (_typ, field) + + if isinstance(_inner, type): + continue + if isinstance(_inner, GenericAlias) and isinstance(_inner.__args__[0], type): + continue + if isinstance(_inner, str): + delay_annotations[name] = _typ.__origin__[ForwardRef(_inner)] + if isinstance(_inner, ForwardRef): + delay_annotations[name] = _inner + if isinstance(_inner, GenericAlias): + delay_annotations[name] = cast(Type[list[ForwardRef]], cls._handle_inner_generic(_inner)) + + return annotations, delay_annotations @classmethod def _get_field_mapping(cls) -> Dict[int, Tuple[str, Type[_ProtoTypes]]]: # Tag, (Name, Type) field_mapping: Dict[int, Tuple[str, Type[_ProtoTypes]]] = {} + if cls._delay_anno_map: + cls._resolve_annotations(cls) for name, (typ, field) in cls._anno_map.items(): field_mapping[field.tag] = (name, typ) return field_mapping @@ -161,7 +192,22 @@ def _get_stored_mapping(self) -> Dict[str, NT]: stored_mapping[name] = getattr(self, name) return stored_mapping - def _encode(self, v: _ProtoTypes) -> NT: + @staticmethod + def _resolve_annotations(arg: Union[Type["ProtoStruct"], "ProtoStruct"]) -> None: + if not arg._delay_anno_map: + return + local = importlib.import_module(arg.__module__).__dict__ + for k, v in arg._delay_anno_map.copy().items(): + casted_forward: Type["ProtoStruct"] + if isinstance(v, GenericAlias): + casted_forward = v.__origin__[v.__args__[0]._evaluate(globals(), local, recursive_guard=frozenset())] + arg._anno_map[k] = (casted_forward, arg._anno_map[k][1]) + if isinstance(v, ForwardRef): + casted_forward = v._evaluate(globals(), local, recursive_guard=frozenset()) # type: ignore + arg._anno_map[k] = (casted_forward, arg._anno_map[k][1]) + arg._delay_anno_map.pop(k) + + def _encode(self, v: _ProtoTypes) -> _ProtoBasicTypes: if isinstance(v, ProtoStruct): v = v.encode() return v @@ -230,4 +276,3 @@ def decode(cls, data: bytes) -> Self: return cls(**kwargs) -