Skip to content

Commit 6ec4a76

Browse files
dbczumar authored and isaacbmiller committed
Adapters: Support JSON serialization of all pydantic types (e.g. datetimes, enums, etc.) (#1853)
* Add

Signed-off-by: dbczumar <[email protected]>

* fix

Signed-off-by: dbczumar <[email protected]>

* fix

Signed-off-by: dbczumar <[email protected]>

* fix

Signed-off-by: dbczumar <[email protected]>

* fix

Signed-off-by: dbczumar <[email protected]>

---------

Signed-off-by: dbczumar <[email protected]>
1 parent 4699adf commit 6ec4a76

File tree

8 files changed

+304
-183
lines changed

8 files changed

+304
-183
lines changed

dspy/adapters/chat_adapter.py

Lines changed: 3 additions & 96 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414

1515
from dsp.adapters.base_template import Field
1616
from dspy.adapters.base import Adapter
17-
from dspy.adapters.image_utils import Image, encode_image
17+
from dspy.adapters.utils import find_enum_member, format_field_value
1818
from dspy.signatures.field import OutputField
1919
from dspy.signatures.signature import Signature, SignatureMeta
2020
from dspy.signatures.utils import get_dspy_field_type
@@ -114,99 +114,6 @@ def format_fields(self, signature, values, role):
114114
return format_fields(fields_with_values)
115115

116116

117-
def format_blob(blob):
118-
if "\n" not in blob and "«" not in blob and "»" not in blob:
119-
return f"«{blob}»"
120-
121-
modified_blob = blob.replace("\n", "\n ")
122-
return f"«««\n {modified_blob}\n»»»"
123-
124-
125-
def format_input_list_field_value(value: List[Any]) -> str:
126-
"""
127-
Formats the value of an input field of type List[Any].
128-
129-
Args:
130-
value: The value of the list-type input field.
131-
Returns:
132-
A string representation of the input field's list value.
133-
"""
134-
if len(value) == 0:
135-
return "N/A"
136-
if len(value) == 1:
137-
return format_blob(value[0])
138-
139-
return "\n".join([f"[{idx+1}] {format_blob(txt)}" for idx, txt in enumerate(value)])
140-
141-
142-
def _serialize_for_json(value):
143-
if isinstance(value, pydantic.BaseModel):
144-
return value.model_dump()
145-
elif isinstance(value, list):
146-
return [_serialize_for_json(item) for item in value]
147-
elif isinstance(value, dict):
148-
return {key: _serialize_for_json(val) for key, val in value.items()}
149-
else:
150-
return value
151-
152-
153-
def _format_field_value(field_info: FieldInfo, value: Any, assume_text=True) -> Union[str, dict]:
154-
"""
155-
Formats the value of the specified field according to the field's DSPy type (input or output),
156-
annotation (e.g. str, int, etc.), and the type of the value itself.
157-
158-
Args:
159-
field_info: Information about the field, including its DSPy field type and annotation.
160-
value: The value of the field.
161-
Returns:
162-
The formatted value of the field, represented as a string.
163-
"""
164-
string_value = None
165-
166-
if field_info.annotation == Image and is_image(value):
167-
print("value: ", value)
168-
value = Image(url=encode_image(value))
169-
# print("field info: ", field_info)
170-
# if not isinstance(value, Image):
171-
# print(f"Coerced image: {value}")
172-
# coerced_image = Image(url=encode_image(value))
173-
# print("post coerce: ", coerced_image)
174-
# string_value = json.dumps(_serialize_for_json(coerced_image), ensure_ascii=False)
175-
if isinstance(value, list) and field_info.annotation is str:
176-
# If the field has no special type requirements, format it as a nice numbered list for the LM.
177-
string_value = format_input_list_field_value(value)
178-
elif isinstance(value, pydantic.BaseModel) or isinstance(value, dict) or isinstance(value, list):
179-
string_value = json.dumps(_serialize_for_json(value), ensure_ascii=False)
180-
else:
181-
string_value = str(value)
182-
183-
if assume_text:
184-
return string_value
185-
186-
# What we actually want is that for any image inside of any arbitrary normal python or pudantic object, when we see it
187-
# it will trigger some sort of escape sequence that we then combine at the end in order to make it a cohesive request to send to OAI
188-
# Hooking too deep into the serialization process is a bad idea, but we need an escape hatch somewhere
189-
190-
# elif (isinstance(value, Image) or field_info.annotation == Image):
191-
# # This validation should happen somewhere else
192-
# # Safe to import PIL here because it's only imported when an image is actually being formatted
193-
# try:
194-
# import PIL
195-
# except ImportError:
196-
# raise ImportError("PIL is required to format images; Run `pip install pillow` to install it.")
197-
# image_value = value
198-
# if not isinstance(image_value, Image):
199-
# if isinstance(image_value, dict) and "url" in image_value:
200-
# image_value = image_value["url"]
201-
# elif isinstance(image_value, str) or isinstance(image_value, PIL.Image.Image):
202-
# image_value = encode_image(image_value)
203-
# assert isinstance(image_value, str)
204-
# image_value = Image(url=image_value)
205-
# return {"type": "image_url", "image_url": image_value.model_dump()}
206-
else:
207-
return {"type": "text", "text": string_value}
208-
209-
210117
def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text=True) -> Union[str, List[dict]]:
211118
"""
212119
Formats the values of the specified fields according to the field's DSPy type (input or output),
@@ -221,7 +128,7 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text=
221128
"""
222129
output = []
223130
for field, field_value in fields_with_values.items():
224-
formatted_field_value = _format_field_value(field_info=field.info, value=field_value, assume_text=assume_text)
131+
formatted_field_value = format_field_value(field_info=field.info, value=field_value, assume_text=assume_text)
225132
if assume_text:
226133
output.append(f"[[ ## {field.name} ## ]]\n{formatted_field_value}")
227134
else:
@@ -243,7 +150,7 @@ def parse_value(value, annotation):
243150
parsed_value = value
244151

245152
if isinstance(annotation, enum.EnumMeta):
246-
parsed_value = annotation[value]
153+
return find_enum_member(annotation, value)
247154
elif isinstance(value, str):
248155
try:
249156
parsed_value = json.loads(value)

dspy/adapters/json_adapter.py

Lines changed: 41 additions & 77 deletions
Original file line number | Diff line number | Diff line change
@@ -1,39 +1,41 @@
11
import ast
2-
import json
32
import enum
43
import inspect
5-
import litellm
6-
import pydantic
4+
import json
75
import textwrap
8-
import json_repair
9-
6+
from typing import Any, Dict, KeysView, Literal, NamedTuple, get_args, get_origin
107

8+
import json_repair
9+
import litellm
10+
import pydantic
1111
from pydantic import TypeAdapter
1212
from pydantic.fields import FieldInfo
13-
from typing import Any, Dict, KeysView, List, Literal, NamedTuple, get_args, get_origin
1413

1514
from dspy.adapters.base import Adapter
15+
from dspy.adapters.utils import find_enum_member, format_field_value, serialize_for_json
16+
1617
from ..adapters.image_utils import Image
1718
from ..signatures.signature import SignatureMeta
1819
from ..signatures.utils import get_dspy_field_type
1920

21+
2022
class FieldInfoWithName(NamedTuple):
2123
name: str
2224
info: FieldInfo
2325

26+
2427
class JSONAdapter(Adapter):
2528
def __init__(self):
2629
pass
2730

2831
def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
2932
inputs = self.format(signature, demos, inputs)
3033
inputs = dict(prompt=inputs) if isinstance(inputs, str) else dict(messages=inputs)
31-
32-
34+
3335
try:
34-
provider = lm.model.split('/', 1)[0] or "openai"
35-
if 'response_format' in litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider):
36-
outputs = lm(**inputs, **lm_kwargs, response_format={ "type": "json_object" })
36+
provider = lm.model.split("/", 1)[0] or "openai"
37+
if "response_format" in litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider):
38+
outputs = lm(**inputs, **lm_kwargs, response_format={"type": "json_object"})
3739
else:
3840
outputs = lm(**inputs, **lm_kwargs)
3941

@@ -44,11 +46,12 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
4446

4547
for output in outputs:
4648
value = self.parse(signature, output, _parse_values=_parse_values)
47-
assert set(value.keys()) == set(signature.output_fields.keys()), f"Expected {signature.output_fields.keys()} but got {value.keys()}"
49+
assert set(value.keys()) == set(
50+
signature.output_fields.keys()
51+
), f"Expected {signature.output_fields.keys()} but got {value.keys()}"
4852
values.append(value)
49-
50-
return values
5153

54+
return values
5255

5356
def format(self, signature, demos, inputs):
5457
messages = []
@@ -71,7 +74,7 @@ def format(self, signature, demos, inputs):
7174
messages.append(format_turn(signature, demo, role="assistant", incomplete=demo in incomplete_demos))
7275

7376
messages.append(format_turn(signature, inputs, role="user"))
74-
77+
7578
return messages
7679

7780
def parse(self, signature, completion, _parse_values=True):
@@ -90,7 +93,7 @@ def parse(self, signature, completion, _parse_values=True):
9093

9194
def format_turn(self, signature, values, role, incomplete=False):
9295
return format_turn(signature, values, role, incomplete)
93-
96+
9497
def format_fields(self, signature, values, role):
9598
fields_with_values = {
9699
FieldInfoWithName(name=field_name, info=field_info): values.get(
@@ -101,16 +104,16 @@ def format_fields(self, signature, values, role):
101104
}
102105

103106
return format_fields(role=role, fields_with_values=fields_with_values)
104-
107+
105108

106109
def parse_value(value, annotation):
107110
if annotation is str:
108111
return str(value)
109-
112+
110113
parsed_value = value
111114

112115
if isinstance(annotation, enum.EnumMeta):
113-
parsed_value = annotation[value]
116+
parsed_value = find_enum_member(annotation, value)
114117
elif isinstance(value, str):
115118
try:
116119
parsed_value = json.loads(value)
@@ -119,45 +122,10 @@ def parse_value(value, annotation):
119122
parsed_value = ast.literal_eval(value)
120123
except (ValueError, SyntaxError):
121124
parsed_value = value
122-
123-
return TypeAdapter(annotation).validate_python(parsed_value)
124-
125125

126-
def format_blob(blob):
127-
if "\n" not in blob and "«" not in blob and "»" not in blob:
128-
return f"«{blob}»"
129-
130-
modified_blob = blob.replace("\n", "\n ")
131-
return f"«««\n {modified_blob}\n»»»"
132-
133-
134-
def format_input_list_field_value(value: List[Any]) -> str:
135-
"""
136-
Formats the value of an input field of type List[Any].
137-
138-
Args:
139-
value: The value of the list-type input field.
140-
Returns:
141-
A string representation of the input field's list value.
142-
"""
143-
if len(value) == 0:
144-
return "N/A"
145-
if len(value) == 1:
146-
return format_blob(value[0])
147-
148-
return "\n".join([f"[{idx+1}] {format_blob(txt)}" for idx, txt in enumerate(value)])
126+
return TypeAdapter(annotation).validate_python(parsed_value)
149127

150128

151-
def _serialize_for_json(value):
152-
if isinstance(value, pydantic.BaseModel):
153-
return value.model_dump()
154-
elif isinstance(value, list):
155-
return [_serialize_for_json(item) for item in value]
156-
elif isinstance(value, dict):
157-
return {key: _serialize_for_json(val) for key, val in value.items()}
158-
else:
159-
return value
160-
161129
def _format_field_value(field_info: FieldInfo, value: Any) -> str:
162130
"""
163131
Formats the value of the specified field according to the field's DSPy type (input or output),
@@ -169,17 +137,10 @@ def _format_field_value(field_info: FieldInfo, value: Any) -> str:
169137
Returns:
170138
The formatted value of the field, represented as a string.
171139
"""
172-
173-
if isinstance(value, list) and field_info.annotation is str:
174-
# If the field has no special type requirements, format it as a nice numbere list for the LM.
175-
return format_input_list_field_value(value)
176140
if field_info.annotation is Image:
177141
raise NotImplementedError("Images are not yet supported in JSON mode.")
178-
elif isinstance(value, pydantic.BaseModel) or isinstance(value, dict) or isinstance(value, list):
179-
return json.dumps(_serialize_for_json(value))
180-
else:
181-
return str(value)
182142

143+
return format_field_value(field_info=field_info, value=value, assume_text=True)
183144

184145

185146
def format_fields(role: str, fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
@@ -197,9 +158,8 @@ def format_fields(role: str, fields_with_values: Dict[FieldInfoWithName, Any]) -
197158

198159
if role == "assistant":
199160
d = fields_with_values.items()
200-
d = {k.name: _serialize_for_json(v) for k, v in d}
201-
202-
return json.dumps(_serialize_for_json(d), indent=2)
161+
d = {k.name: v for k, v in d}
162+
return json.dumps(serialize_for_json(d), indent=2)
203163

204164
output = []
205165
for field, field_value in fields_with_values.items():
@@ -246,15 +206,19 @@ def format_turn(signature: SignatureMeta, values: Dict[str, Any], role, incomple
246206
field_name, "Not supplied for this particular example."
247207
)
248208
for field_name, field_info in fields.items()
249-
}
209+
},
250210
)
251211
content.append(formatted_fields)
252212

253213
if role == "user":
214+
254215
def type_info(v):
255-
return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" \
256-
if v.annotation is not str else ""
257-
216+
return (
217+
f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})"
218+
if v.annotation is not str
219+
else ""
220+
)
221+
258222
# TODO: Consider if not incomplete:
259223
content.append(
260224
"Respond with a JSON object in the following order of fields: "
@@ -297,15 +261,15 @@ def prepare_instructions(signature: SignatureMeta):
297261
def field_metadata(field_name, field_info):
298262
type_ = field_info.annotation
299263

300-
if get_dspy_field_type(field_info) == 'input' or type_ is str:
264+
if get_dspy_field_type(field_info) == "input" or type_ is str:
301265
desc = ""
302266
elif type_ is bool:
303267
desc = "must be True or False"
304268
elif type_ in (int, float):
305269
desc = f"must be a single {type_.__name__} value"
306270
elif inspect.isclass(type_) and issubclass(type_, enum.Enum):
307-
desc= f"must be one of: {'; '.join(type_.__members__)}"
308-
elif hasattr(type_, '__origin__') and type_.__origin__ is Literal:
271+
desc = f"must be one of: {'; '.join(type_.__members__)}"
272+
elif hasattr(type_, "__origin__") and type_.__origin__ is Literal:
309273
desc = f"must be one of: {'; '.join([str(x) for x in type_.__args__])}"
310274
else:
311275
desc = "must be pareseable according to the following JSON schema: "
@@ -320,13 +284,13 @@ def format_signature_fields_for_instructions(role, fields: Dict[str, FieldInfo])
320284
fields_with_values={
321285
FieldInfoWithName(name=field_name, info=field_info): field_metadata(field_name, field_info)
322286
for field_name, field_info in fields.items()
323-
}
287+
},
324288
)
325-
289+
326290
parts.append("Inputs will have the following structure:")
327-
parts.append(format_signature_fields_for_instructions('user', signature.input_fields))
291+
parts.append(format_signature_fields_for_instructions("user", signature.input_fields))
328292
parts.append("Outputs will be a JSON object with the following fields.")
329-
parts.append(format_signature_fields_for_instructions('assistant', signature.output_fields))
293+
parts.append(format_signature_fields_for_instructions("assistant", signature.output_fields))
330294
# parts.append(format_fields({BuiltInCompletedOutputFieldInfo: ""}))
331295

332296
instructions = textwrap.dedent(signature.instructions)

0 commit comments

Comments
 (0)