11
22from  typing  import  Type , Dict , Optional , List , Tuple , Any , Union 
3- from  pydantic  import  BaseModel , confloat 
3+ from  pydantic  import  BaseModel , confloat , Field 
4+ from  label_studio_sdk .label_interface .objects  import  PredictionValue 
5+ from  typing  import  Union , List 
46
5- from  label_studio_sdk .objects  import  PredictionValue 
7+ 
8+ # one or multiple predictions per task 
9+ SingleTaskPredictions  =  Union [List [PredictionValue ], PredictionValue ]
610
711
812class  ModelResponse (BaseModel ):
913    """ 
1014    """ 
1115    model_version : Optional [str ] =  None 
12-     predictions : List [PredictionValue ]
16+     predictions : List [SingleTaskPredictions ]
1317
1418    def  has_model_version (self ) ->  bool :
1519        return  bool (self .model_version )
@@ -18,21 +22,16 @@ def update_predictions_version(self) -> None:
1822        """ 
1923        """ 
2024        for  prediction  in  self .predictions :
21-             if  not  prediction .model_version :
22-                 prediction .model_version  =  self .model_version 
25+             if  isinstance (prediction , PredictionValue ):
26+                 prediction  =  [prediction ]
27+             for  p  in  prediction :
28+                 if  not  p .model_version :
29+                     p .model_version  =  self .model_version 
2330
2431    def  set_version (self , version : str ) ->  None :
2532        """ 
2633        """ 
2734        self .model_version  =  version 
2835        # Set the version for each prediction 
2936        self .update_predictions_version ()
30- 
31-     def  serialize (self ):
32-         """ 
33-         """ 
34-         return  {
35-             "model_version" : self .model_version ,
36-             "predictions" : [ p .serialize () for  p  in  self .predictions  ]
37-         }
3837
0 commit comments