diff --git a/starknet_py/hash/outside_execution.py b/starknet_py/hash/outside_execution.py index a5c5f8f38..cd426b765 100644 --- a/starknet_py/hash/outside_execution.py +++ b/starknet_py/hash/outside_execution.py @@ -1,6 +1,6 @@ from starknet_py.constants import OutsideExecutionInterfaceID from starknet_py.net.client_models import OutsideExecution -from starknet_py.net.schemas.common import Revision +from starknet_py.net.models.typed_data import Revision from starknet_py.utils.typed_data import TypedData OUTSIDE_EXECUTION_INTERFACE_ID_TO_TYPED_DATA_REVISION = { diff --git a/starknet_py/net/models/typed_data.py b/starknet_py/net/models/typed_data.py index 972391b29..c95b3322f 100644 --- a/starknet_py/net/models/typed_data.py +++ b/starknet_py/net/models/typed_data.py @@ -3,16 +3,24 @@ """ import sys +from enum import Enum from typing import Any, Dict, List, Optional, TypedDict -from starknet_py.net.schemas.common import Revision - if sys.version_info < (3, 11): from typing_extensions import NotRequired else: from typing import NotRequired +class Revision(Enum): + """ + Enum representing the revision of the specification to be used. + """ + + V0 = 0 + V1 = 1 + + class ParameterDict(TypedDict): """ TypedDict representing a Parameter object diff --git a/starknet_py/net/schemas/common.py b/starknet_py/net/schemas/common.py index 7f296a847..7653ddfff 100644 --- a/starknet_py/net/schemas/common.py +++ b/starknet_py/net/schemas/common.py @@ -1,6 +1,5 @@ import re import sys -from enum import Enum from typing import Any, Mapping, Optional, Union from marshmallow import Schema, ValidationError, fields, post_load @@ -377,35 +376,3 @@ class StorageEntrySchema(Schema): def make_dataclass(self, data, **kwargs): # pylint: disable=no-self-use return StorageEntry(**data) - - -class Revision(Enum): - """ - Enum representing the revision of the specification to be used. - """ - - V0 = 0 - V1 = 1 - - -class RevisionField(fields.Field): - def _serialize(self, value: Any, attr: Optional[str], obj: Any, **kwargs): - if value is None or value == Revision.V0: - return str(Revision.V0.value) - return value.value - - def _deserialize(self, value, attr, data, **kwargs) -> Revision: - if isinstance(value, str): - value = int(value) - - if isinstance(value, Revision): - value = value.value - - revisions = [revision.value for revision in Revision] - if value not in revisions: - allowed_revisions_str = "".join(list(map(str, revisions))) - raise ValidationError( - f"Invalid value provided for Revision: {value}. Allowed values are {allowed_revisions_str}." - ) - - return Revision(value) diff --git a/starknet_py/net/schemas/revision.py b/starknet_py/net/schemas/revision.py new file mode 100644 index 000000000..ec52ab5fd --- /dev/null +++ b/starknet_py/net/schemas/revision.py @@ -0,0 +1,28 @@ +from typing import Any, Optional + +from marshmallow import ValidationError, fields + +from starknet_py.net.models.typed_data import Revision + + +class RevisionField(fields.Field): + def _serialize(self, value: Any, attr: Optional[str], obj: Any, **kwargs): + if value is None or value == Revision.V0: + return str(Revision.V0.value) + return value.value + + def _deserialize(self, value, attr, data, **kwargs) -> Revision: + if isinstance(value, str): + value = int(value) + + if isinstance(value, Revision): + value = value.value + + revisions = [revision.value for revision in Revision] + if value not in revisions: + allowed_revisions_str = "".join(list(map(str, revisions))) + raise ValidationError( + f"Invalid value provided for Revision: {value}. Allowed values are {allowed_revisions_str}." + ) + + return Revision(value) diff --git a/starknet_py/utils/typed_data.py b/starknet_py/utils/typed_data.py index ebdd90020..8bd9f0acc 100644 --- a/starknet_py/utils/typed_data.py +++ b/starknet_py/utils/typed_data.py @@ -12,7 +12,7 @@ from starknet_py.hash.selector import get_selector_from_name from starknet_py.net.client_utils import _to_rpc_felt from starknet_py.net.models.typed_data import DomainDict, Revision, TypedDataDict -from starknet_py.net.schemas.common import RevisionField +from starknet_py.net.schemas.revision import RevisionField from starknet_py.serialization.data_serializers import ByteArraySerializer from starknet_py.utils.merkle_tree import MerkleTree