Skip to content

Commit e572b2b

Browse files
authored
feat: add msgspec support (#154)
This PR adds an optional support for msgspec
1 parent 14f29dd commit e572b2b

File tree

5 files changed

+138
-10
lines changed

5 files changed

+138
-10
lines changed

examples/msgspec_greeter.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#
2+
# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH
3+
#
4+
# This file is part of the Restate SDK for Python,
5+
# which is released under the MIT license.
6+
#
7+
# You can find a copy of the license in file LICENSE in the root
8+
# directory of this repository or package, or at
9+
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
10+
#
11+
"""msgspec_greeter.py - Example using msgspec.Struct with Restate"""
12+
# pylint: disable=C0116
13+
# pylint: disable=W0613
14+
# pylint: disable=C0115
15+
# pylint: disable=R0903
16+
17+
import msgspec
18+
from restate import Service, Context
19+
20+
21+
# models
22+
class GreetingRequest(msgspec.Struct):
23+
name: str
24+
25+
26+
class Greeting(msgspec.Struct):
27+
message: str
28+
29+
30+
# service
31+
32+
msgspec_greeter = Service("msgspec_greeter")
33+
34+
35+
@msgspec_greeter.handler()
36+
async def greet(ctx: Context, req: GreetingRequest) -> Greeting:
37+
return Greeting(message=f"Hello {req.name}!")
38+

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ Source = "https://github.com/restatedev/sdk-python"
2424
test = ["pytest", "hypercorn", "anyio"]
2525
lint = ["mypy>=1.11.2", "pyright>=1.1.390", "ruff>=0.6.9"]
2626
harness = ["testcontainers", "hypercorn", "httpx"]
27-
serde = ["dacite", "pydantic"]
27+
serde = ["dacite", "pydantic", "msgspec"]
2828
client = ["httpx[http2]"]
2929

3030
[build-system]

python/restate/discovery.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,10 @@ def json_schema_from_type_hint(type_hint: Optional[TypeHint[Any]]) -> Any:
218218
return None
219219
if not type_hint.annotation:
220220
return None
221+
if type_hint.is_msgspec:
222+
import msgspec.json # type: ignore # pylint: disable=import-outside-toplevel
223+
224+
return msgspec.json.schema(type_hint.annotation)
221225
if type_hint.is_pydantic:
222226
return type_hint.annotation.model_json_schema(mode="serialization")
223227
return type_hint_to_json_schema(type_hint.annotation)

python/restate/handler.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from restate.context import HandlerType
2626
from restate.exceptions import TerminalError
27-
from restate.serde import DefaultSerde, PydanticJsonSerde, Serde, is_pydantic
27+
from restate.serde import DefaultSerde, PydanticJsonSerde, MsgspecJsonSerde, Serde, is_pydantic, is_msgspec
2828

2929
I = TypeVar("I")
3030
O = TypeVar("O")
@@ -54,6 +54,7 @@ class TypeHint(Generic[T]):
5454

5555
annotation: Optional[T] = None
5656
is_pydantic: bool = False
57+
is_msgspec: bool = False
5758
is_void: bool = False
5859

5960

@@ -79,20 +80,24 @@ def update_handler_io_with_type_hints(handler_io: HandlerIO[I, O], signature: Si
7980
"""
8081
Augment handler_io with additional information about the input and output types.
8182
82-
This function has a special check for Pydantic models when these are provided.
83+
This function has a special check for msgspec Structs and Pydantic models when these are provided.
8384
This method will inspect the signature of an handler and will look for
8485
the input and the return types of a function, and will:
85-
* capture any Pydantic models (to be used later at discovery)
86-
* replace the default json serializer (is unchanged by a user) with a Pydantic serde
86+
* capture any msgspec Structs or Pydantic models (to be used later at discovery)
87+
* replace the default json serializer (is unchanged by a user) with the appropriate serde
8788
"""
8889
params = list(signature.parameters.values())
8990
if len(params) == 1:
9091
# if there is only one parameter, it is the context.
9192
handler_io.input_type = TypeHint(is_void=True)
9293
else:
9394
annotation = params[-1].annotation
94-
handler_io.input_type = TypeHint(annotation=annotation, is_pydantic=False)
95-
if is_pydantic(annotation):
95+
handler_io.input_type = TypeHint(annotation=annotation, is_pydantic=False, is_msgspec=False)
96+
if is_msgspec(annotation):
97+
handler_io.input_type.is_msgspec = True
98+
if isinstance(handler_io.input_serde, DefaultSerde):
99+
handler_io.input_serde = MsgspecJsonSerde(annotation)
100+
elif is_pydantic(annotation):
96101
handler_io.input_type.is_pydantic = True
97102
if isinstance(handler_io.input_serde, DefaultSerde):
98103
handler_io.input_serde = PydanticJsonSerde(annotation)
@@ -102,8 +107,12 @@ def update_handler_io_with_type_hints(handler_io: HandlerIO[I, O], signature: Si
102107
# if there is no return annotation, we assume it is void
103108
handler_io.output_type = TypeHint(is_void=True)
104109
else:
105-
handler_io.output_type = TypeHint(annotation=annotation, is_pydantic=False)
106-
if is_pydantic(annotation):
110+
handler_io.output_type = TypeHint(annotation=annotation, is_pydantic=False, is_msgspec=False)
111+
if is_msgspec(annotation):
112+
handler_io.output_type.is_msgspec = True
113+
if isinstance(handler_io.output_serde, DefaultSerde):
114+
handler_io.output_serde = MsgspecJsonSerde(annotation)
115+
elif is_pydantic(annotation):
107116
handler_io.output_type.is_pydantic = True
108117
if isinstance(handler_io.output_serde, DefaultSerde):
109118
handler_io.output_serde = PydanticJsonSerde(annotation)

python/restate/serde.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,24 @@ def _from_dict(data_class: typing.Any, data: typing.Any) -> typing.Any: # pylin
7474
return _to_dict, _from_dict
7575

7676

77+
def try_import_msgspec_struct():
78+
"""
79+
Try to import Struct from msgspec.
80+
"""
81+
try:
82+
from msgspec import Struct # type: ignore # pylint: disable=import-outside-toplevel
83+
84+
return Struct
85+
except ImportError:
86+
87+
class Dummy: # pylint: disable=too-few-public-methods
88+
"""a dummy class to use when msgspec is not available"""
89+
90+
return Dummy
91+
92+
7793
PydanticBaseModel = try_import_pydantic_base_model()
94+
MsgspecStruct = try_import_msgspec_struct()
7895
# pylint: disable=C0103
7996
DaciteToDict, DaciteFromDict = try_import_from_dacite()
8097

@@ -97,6 +114,17 @@ def is_pydantic(annotation) -> bool:
97114
return False
98115

99116

117+
def is_msgspec(annotation) -> bool:
118+
"""
119+
Check if an object is a msgspec Struct.
120+
"""
121+
try:
122+
return issubclass(annotation, MsgspecStruct)
123+
except TypeError:
124+
# annotation is not a class or a type
125+
return False
126+
127+
100128
class Serde(typing.Generic[T], abc.ABC):
101129
"""serializer/deserializer interface."""
102130

@@ -227,6 +255,10 @@ def deserialize(self, buf: bytes) -> typing.Optional[I]:
227255
"""
228256
if not buf:
229257
return None
258+
if is_msgspec(self.type_hint):
259+
import msgspec.json # type: ignore # pylint: disable=import-outside-toplevel
260+
261+
return msgspec.json.decode(buf, type=self.type_hint)
230262
if is_pydantic(self.type_hint):
231263
return self.type_hint.model_validate_json(buf) # type: ignore
232264
if is_dataclass(self.type_hint):
@@ -237,7 +269,7 @@ def deserialize(self, buf: bytes) -> typing.Optional[I]:
237269
def serialize(self, obj: typing.Optional[I]) -> bytes:
238270
"""
239271
Serializes a Python object into a byte array.
240-
If the object is a Pydantic BaseModel, uses its model_dump_json method.
272+
If the object is a msgspec Struct or Pydantic BaseModel, uses their respective methods.
241273
242274
Args:
243275
obj (Optional[I]): The Python object to serialize.
@@ -247,6 +279,10 @@ def serialize(self, obj: typing.Optional[I]) -> bytes:
247279
"""
248280
if obj is None:
249281
return bytes()
282+
if is_msgspec(self.type_hint):
283+
import msgspec.json # type: ignore # pylint: disable=import-outside-toplevel
284+
285+
return msgspec.json.encode(obj)
250286
if is_pydantic(self.type_hint):
251287
return obj.model_dump_json().encode("utf-8") # type: ignore[attr-defined]
252288
if is_dataclass(obj):
@@ -291,3 +327,44 @@ def serialize(self, obj: typing.Optional[I]) -> bytes:
291327
return bytes()
292328
json_str = obj.model_dump_json() # type: ignore[attr-defined]
293329
return json_str.encode("utf-8")
330+
331+
332+
class MsgspecJsonSerde(Serde[I]):
333+
"""
334+
Serde for msgspec Structs to/from JSON
335+
"""
336+
337+
def __init__(self, model):
338+
self.model = model
339+
340+
def deserialize(self, buf: bytes) -> typing.Optional[I]:
341+
"""
342+
Deserializes a bytearray to a msgspec Struct.
343+
344+
Args:
345+
buf (bytearray): The bytearray to deserialize.
346+
347+
Returns:
348+
typing.Optional[I]: The deserialized msgspec Struct.
349+
"""
350+
if not buf:
351+
return None
352+
import msgspec.json # type: ignore # pylint: disable=import-outside-toplevel
353+
354+
return msgspec.json.decode(buf, type=self.model)
355+
356+
def serialize(self, obj: typing.Optional[I]) -> bytes:
357+
"""
358+
Serializes a msgspec Struct to a bytearray.
359+
360+
Args:
361+
obj (I): The msgspec Struct to serialize.
362+
363+
Returns:
364+
bytearray: The serialized bytearray.
365+
"""
366+
if obj is None:
367+
return bytes()
368+
import msgspec.json # type: ignore # pylint: disable=import-outside-toplevel
369+
370+
return msgspec.json.encode(obj)

0 commit comments

Comments
 (0)