11import ast
2- import json
32import enum
43import inspect
5- import litellm
6- import pydantic
4+ import json
75import textwrap
8- import json_repair
9-
6+ from typing import Any , Dict , KeysView , Literal , NamedTuple , get_args , get_origin
107
8+ import json_repair
9+ import litellm
10+ import pydantic
1111from pydantic import TypeAdapter
1212from pydantic .fields import FieldInfo
13- from typing import Any , Dict , KeysView , List , Literal , NamedTuple , get_args , get_origin
1413
1514from dspy .adapters .base import Adapter
15+ from dspy .adapters .utils import find_enum_member , format_field_value , serialize_for_json
16+
1617from ..adapters .image_utils import Image
1718from ..signatures .signature import SignatureMeta
1819from ..signatures .utils import get_dspy_field_type
1920
21+
class FieldInfoWithName(NamedTuple):
    """Pair a field's name with its pydantic ``FieldInfo``.

    Instances are hashable tuples and are used in this module as dictionary
    keys that map a field to the value (or description) to be rendered for it.
    """

    name: str  # the field's name
    info: FieldInfo  # pydantic metadata object for that field
2325
26+
2427class JSONAdapter (Adapter ):
2528 def __init__ (self ):
2629 pass
2730
2831 def __call__ (self , lm , lm_kwargs , signature , demos , inputs , _parse_values = True ):
2932 inputs = self .format (signature , demos , inputs )
3033 inputs = dict (prompt = inputs ) if isinstance (inputs , str ) else dict (messages = inputs )
31-
32-
34+
3335 try :
34- provider = lm .model .split ('/' , 1 )[0 ] or "openai"
35- if ' response_format' in litellm .get_supported_openai_params (model = lm .model , custom_llm_provider = provider ):
36- outputs = lm (** inputs , ** lm_kwargs , response_format = { "type" : "json_object" })
36+ provider = lm .model .split ("/" , 1 )[0 ] or "openai"
37+ if " response_format" in litellm .get_supported_openai_params (model = lm .model , custom_llm_provider = provider ):
38+ outputs = lm (** inputs , ** lm_kwargs , response_format = {"type" : "json_object" })
3739 else :
3840 outputs = lm (** inputs , ** lm_kwargs )
3941
@@ -44,11 +46,12 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
4446
4547 for output in outputs :
4648 value = self .parse (signature , output , _parse_values = _parse_values )
47- assert set (value .keys ()) == set (signature .output_fields .keys ()), f"Expected { signature .output_fields .keys ()} but got { value .keys ()} "
49+ assert set (value .keys ()) == set (
50+ signature .output_fields .keys ()
51+ ), f"Expected { signature .output_fields .keys ()} but got { value .keys ()} "
4852 values .append (value )
49-
50- return values
5153
54+ return values
5255
5356 def format (self , signature , demos , inputs ):
5457 messages = []
@@ -71,7 +74,7 @@ def format(self, signature, demos, inputs):
7174 messages .append (format_turn (signature , demo , role = "assistant" , incomplete = demo in incomplete_demos ))
7275
7376 messages .append (format_turn (signature , inputs , role = "user" ))
74-
77+
7578 return messages
7679
7780 def parse (self , signature , completion , _parse_values = True ):
@@ -90,7 +93,7 @@ def parse(self, signature, completion, _parse_values=True):
9093
9194 def format_turn (self , signature , values , role , incomplete = False ):
9295 return format_turn (signature , values , role , incomplete )
93-
96+
9497 def format_fields (self , signature , values , role ):
9598 fields_with_values = {
9699 FieldInfoWithName (name = field_name , info = field_info ): values .get (
@@ -101,16 +104,16 @@ def format_fields(self, signature, values, role):
101104 }
102105
103106 return format_fields (role = role , fields_with_values = fields_with_values )
104-
107+
105108
106109def parse_value (value , annotation ):
107110 if annotation is str :
108111 return str (value )
109-
112+
110113 parsed_value = value
111114
112115 if isinstance (annotation , enum .EnumMeta ):
113- parsed_value = annotation [ value ]
116+ parsed_value = find_enum_member ( annotation , value )
114117 elif isinstance (value , str ):
115118 try :
116119 parsed_value = json .loads (value )
@@ -119,45 +122,10 @@ def parse_value(value, annotation):
119122 parsed_value = ast .literal_eval (value )
120123 except (ValueError , SyntaxError ):
121124 parsed_value = value
122-
123- return TypeAdapter (annotation ).validate_python (parsed_value )
124-
125125
126- def format_blob (blob ):
127- if "\n " not in blob and "«" not in blob and "»" not in blob :
128- return f"«{ blob } »"
129-
130- modified_blob = blob .replace ("\n " , "\n " )
131- return f"«««\n { modified_blob } \n »»»"
132-
133-
134- def format_input_list_field_value (value : List [Any ]) -> str :
135- """
136- Formats the value of an input field of type List[Any].
137-
138- Args:
139- value: The value of the list-type input field.
140- Returns:
141- A string representation of the input field's list value.
142- """
143- if len (value ) == 0 :
144- return "N/A"
145- if len (value ) == 1 :
146- return format_blob (value [0 ])
147-
148- return "\n " .join ([f"[{ idx + 1 } ] { format_blob (txt )} " for idx , txt in enumerate (value )])
126+ return TypeAdapter (annotation ).validate_python (parsed_value )
149127
150128
151- def _serialize_for_json (value ):
152- if isinstance (value , pydantic .BaseModel ):
153- return value .model_dump ()
154- elif isinstance (value , list ):
155- return [_serialize_for_json (item ) for item in value ]
156- elif isinstance (value , dict ):
157- return {key : _serialize_for_json (val ) for key , val in value .items ()}
158- else :
159- return value
160-
def _format_field_value(field_info: FieldInfo, value: Any) -> str:
    """
    Format the value of a field as a string for use in JSON mode.

    Image fields are rejected up front; every other value is handed to the
    shared ``format_field_value`` helper with ``assume_text=True``.

    Args:
        field_info: The pydantic field metadata; only its ``annotation`` is inspected here.
        value: The value of the field.
    Returns:
        The formatted value of the field, represented as a string.
    Raises:
        NotImplementedError: If the field's annotation is ``Image``.
    """
    if field_info.annotation is Image:
        raise NotImplementedError("Images are not yet supported in JSON mode.")

    return format_field_value(field_info=field_info, value=value, assume_text=True)
183144
184145
185146def format_fields (role : str , fields_with_values : Dict [FieldInfoWithName , Any ]) -> str :
@@ -197,9 +158,8 @@ def format_fields(role: str, fields_with_values: Dict[FieldInfoWithName, Any]) -
197158
198159 if role == "assistant" :
199160 d = fields_with_values .items ()
200- d = {k .name : _serialize_for_json (v ) for k , v in d }
201-
202- return json .dumps (_serialize_for_json (d ), indent = 2 )
161+ d = {k .name : v for k , v in d }
162+ return json .dumps (serialize_for_json (d ), indent = 2 )
203163
204164 output = []
205165 for field , field_value in fields_with_values .items ():
@@ -246,15 +206,19 @@ def format_turn(signature: SignatureMeta, values: Dict[str, Any], role, incomple
246206 field_name , "Not supplied for this particular example."
247207 )
248208 for field_name , field_info in fields .items ()
249- }
209+ },
250210 )
251211 content .append (formatted_fields )
252212
253213 if role == "user" :
214+
254215 def type_info (v ):
255- return f" (must be formatted as a valid Python { get_annotation_name (v .annotation )} )" \
256- if v .annotation is not str else ""
257-
216+ return (
217+ f" (must be formatted as a valid Python { get_annotation_name (v .annotation )} )"
218+ if v .annotation is not str
219+ else ""
220+ )
221+
258222 # TODO: Consider if not incomplete:
259223 content .append (
260224 "Respond with a JSON object in the following order of fields: "
@@ -297,15 +261,15 @@ def prepare_instructions(signature: SignatureMeta):
297261 def field_metadata (field_name , field_info ):
298262 type_ = field_info .annotation
299263
300- if get_dspy_field_type (field_info ) == ' input' or type_ is str :
264+ if get_dspy_field_type (field_info ) == " input" or type_ is str :
301265 desc = ""
302266 elif type_ is bool :
303267 desc = "must be True or False"
304268 elif type_ in (int , float ):
305269 desc = f"must be a single { type_ .__name__ } value"
306270 elif inspect .isclass (type_ ) and issubclass (type_ , enum .Enum ):
307- desc = f"must be one of: { '; ' .join (type_ .__members__ )} "
308- elif hasattr (type_ , ' __origin__' ) and type_ .__origin__ is Literal :
271+ desc = f"must be one of: { '; ' .join (type_ .__members__ )} "
272+ elif hasattr (type_ , " __origin__" ) and type_ .__origin__ is Literal :
309273 desc = f"must be one of: { '; ' .join ([str (x ) for x in type_ .__args__ ])} "
310274 else :
311275 desc = "must be pareseable according to the following JSON schema: "
@@ -320,13 +284,13 @@ def format_signature_fields_for_instructions(role, fields: Dict[str, FieldInfo])
320284 fields_with_values = {
321285 FieldInfoWithName (name = field_name , info = field_info ): field_metadata (field_name , field_info )
322286 for field_name , field_info in fields .items ()
323- }
287+ },
324288 )
325-
289+
326290 parts .append ("Inputs will have the following structure:" )
327- parts .append (format_signature_fields_for_instructions (' user' , signature .input_fields ))
291+ parts .append (format_signature_fields_for_instructions (" user" , signature .input_fields ))
328292 parts .append ("Outputs will be a JSON object with the following fields." )
329- parts .append (format_signature_fields_for_instructions (' assistant' , signature .output_fields ))
293+ parts .append (format_signature_fields_for_instructions (" assistant" , signature .output_fields ))
330294 # parts.append(format_fields({BuiltInCompletedOutputFieldInfo: ""}))
331295
332296 instructions = textwrap .dedent (signature .instructions )
0 commit comments