1010
1111
1212def to_strict_json_schema (model : type [pydantic .BaseModel ]) -> dict [str , Any ]:
13- return _ensure_strict_json_schema (model_json_schema (model ), path = ())
13+ schema = model_json_schema (model )
14+ return _ensure_strict_json_schema (schema , path = (), root = schema )
1415
1516
1617def _ensure_strict_json_schema (
1718 json_schema : object ,
19+ * ,
1820 path : tuple [str , ...],
21+ root : dict [str , object ],
1922) -> dict [str , Any ]:
2023 """Mutates the given JSON schema to ensure it conforms to the `strict` standard
2124 that the API expects.
2225 """
2326 if not is_dict (json_schema ):
2427 raise TypeError (f"Expected { json_schema } to be a dictionary; path={ path } " )
2528
29+ defs = json_schema .get ("$defs" )
30+ if is_dict (defs ):
31+ for def_name , def_schema in defs .items ():
32+ _ensure_strict_json_schema (def_schema , path = (* path , "$defs" , def_name ), root = root )
33+
34+ definitions = json_schema .get ("definitions" )
35+ if is_dict (definitions ):
36+ for definition_name , definition_schema in definitions .items ():
37+ _ensure_strict_json_schema (definition_schema , path = (* path , "definitions" , definition_name ), root = root )
38+
2639 typ = json_schema .get ("type" )
2740 if typ == "object" and "additionalProperties" not in json_schema :
2841 json_schema ["additionalProperties" ] = False
@@ -33,48 +46,80 @@ def _ensure_strict_json_schema(
3346 if is_dict (properties ):
3447 json_schema ["required" ] = [prop for prop in properties .keys ()]
3548 json_schema ["properties" ] = {
36- key : _ensure_strict_json_schema (prop_schema , path = (* path , "properties" , key ))
49+ key : _ensure_strict_json_schema (prop_schema , path = (* path , "properties" , key ), root = root )
3750 for key , prop_schema in properties .items ()
3851 }
3952
4053 # arrays
4154 # { 'type': 'array', 'items': {...} }
4255 items = json_schema .get ("items" )
4356 if is_dict (items ):
44- json_schema ["items" ] = _ensure_strict_json_schema (items , path = (* path , "items" ))
57+ json_schema ["items" ] = _ensure_strict_json_schema (items , path = (* path , "items" ), root = root )
4558
4659 # unions
4760 any_of = json_schema .get ("anyOf" )
4861 if is_list (any_of ):
4962 json_schema ["anyOf" ] = [
50- _ensure_strict_json_schema (variant , path = (* path , "anyOf" , str (i ))) for i , variant in enumerate (any_of )
63+ _ensure_strict_json_schema (variant , path = (* path , "anyOf" , str (i )), root = root )
64+ for i , variant in enumerate (any_of )
5165 ]
5266
5367 # intersections
5468 all_of = json_schema .get ("allOf" )
5569 if is_list (all_of ):
5670 if len (all_of ) == 1 :
57- json_schema .update (_ensure_strict_json_schema (all_of [0 ], path = (* path , "allOf" , "0" )))
71+ json_schema .update (_ensure_strict_json_schema (all_of [0 ], path = (* path , "allOf" , "0" ), root = root ))
5872 json_schema .pop ("allOf" )
5973 else :
6074 json_schema ["allOf" ] = [
61- _ensure_strict_json_schema (entry , path = (* path , "allOf" , str (i ))) for i , entry in enumerate (all_of )
75+ _ensure_strict_json_schema (entry , path = (* path , "allOf" , str (i )), root = root )
76+ for i , entry in enumerate (all_of )
6277 ]
6378
64- defs = json_schema .get ("$defs" )
65- if is_dict (defs ):
66- for def_name , def_schema in defs .items ():
67- _ensure_strict_json_schema (def_schema , path = (* path , "$defs" , def_name ))
79+ # we can't use `$ref`s if there are also other properties defined, e.g.
80+ # `{"$ref": "...", "description": "my description"}`
81+ #
82+ # so we unravel the ref
83+ # `{"type": "string", "description": "my description"}`
84+ ref = json_schema .get ("$ref" )
85+ if ref and has_more_than_n_keys (json_schema , 1 ):
86+ assert isinstance (ref , str ), f"Received non-string $ref - { ref } "
6887
69- definitions = json_schema .get ("definitions" )
70- if is_dict (definitions ):
71- for definition_name , definition_schema in definitions .items ():
72- _ensure_strict_json_schema (definition_schema , path = (* path , "definitions" , definition_name ))
88+ resolved = resolve_ref (root = root , ref = ref )
89+ if not is_dict (resolved ):
90+ raise ValueError (f"Expected `$ref: { ref } ` to resolved to a dictionary but got { resolved } " )
91+
92+ # properties from the json schema take priority over the ones on the `$ref`
93+ json_schema .update ({** resolved , ** json_schema })
94+ json_schema .pop ("$ref" )
7395
7496 return json_schema
7597
7698
99+ def resolve_ref (* , root : dict [str , object ], ref : str ) -> object :
100+ if not ref .startswith ("#/" ):
101+ raise ValueError (f"Unexpected $ref format { ref !r} ; Does not start with #/" )
102+
103+ path = ref [2 :].split ("/" )
104+ resolved = root
105+ for key in path :
106+ value = resolved [key ]
107+ assert is_dict (value ), f"encountered non-dictionary entry while resolving { ref } - { resolved } "
108+ resolved = value
109+
110+ return resolved
111+
112+
77113def is_dict (obj : object ) -> TypeGuard [dict [str , object ]]:
78114 # just pretend that we know there are only `str` keys
79115 # as that check is not worth the performance cost
80116 return _is_dict (obj )
117+
118+
119+ def has_more_than_n_keys (obj : dict [str , object ], n : int ) -> bool :
120+ i = 0
121+ for _ in obj .keys ():
122+ i += 1
123+ if i > n :
124+ return True
125+ return False
0 commit comments