44"""
55import json
66import logging
7+ from typing import Optional
78
89import numpy as np
910
@@ -104,24 +105,43 @@ def _from_json(self, body_list):
104105 logger .debug ("Bytes array is %s" , body_list )
105106
106107 input_names = []
107- for index , input in enumerate (body_list [0 ]["inputs" ]):
108- if input ["datatype" ] == "BYTES" :
109- body_list [0 ]["inputs" ][index ]["data" ] = input ["data" ][0 ]
110- else :
111- body_list [0 ]["inputs" ][index ]["data" ] = (
112- np .array (input ["data" ]).reshape (tuple (input ["shape" ])).tolist ()
113- )
114- input_names .append (input ["name" ])
108+ parameters = []
109+ ids = []
110+ input_parameters = []
111+ data_list = []
112+
113+ for body in body_list :
114+ id = body .get ("id" )
115+ ids .append (id )
116+ params = body .get ("parameters" )
117+ if params :
118+ parameters .append (params )
119+ inp_names = []
120+ inp_params = []
121+ for i , input in enumerate (body ["inputs" ]):
122+ params = input .get ("parameters" )
123+ if params :
124+ inp_params .append (params )
125+ if input ["datatype" ] == "BYTES" :
126+ body ["inputs" ][i ]["data" ] = input ["data" ][0 ]
127+ else :
128+ body ["inputs" ][i ]["data" ] = (
129+ np .array (input ["data" ]).reshape (tuple (input ["shape" ])).tolist ()
130+ )
131+ inp_names .append (input ["name" ])
132+ data = body ["inputs" ] if len (body ["inputs" ]) > 1 else body ["inputs" ][0 ]
133+ data_list .append (data )
134+
135+ input_parameters .append (inp_params )
136+ input_names .append (inp_names )
137+
138+ setattr (self .context , "input_request_id" , ids )
115139 setattr (self .context , "input_names" , input_names )
116- logger .debug ("Bytes array is %s" , body_list )
117- id = body_list [0 ].get ("id" )
118- if id and id .strip ():
119- setattr (self .context , "input_request_id" , body_list [0 ]["id" ])
120- # TODO: Add parameters support
121- # parameters = body_list[0].get("parameters")
122- # if parameters:
123- # setattr(self.context, "input_parameters", body_list[0]["parameters"])
124- data_list = [inputs_list .get ("inputs" ) for inputs_list in body_list ][0 ]
140+ setattr (self .context , "request_parameters" , parameters )
141+ setattr (self .context , "input_parameters" , input_parameters )
142+ logger .debug ("Data array is %s" , data_list )
143+ logger .debug ("Request paraemeters array is %s" , parameters )
144+ logger .debug ("Input parameters is %s" , input_parameters )
125145 return data_list
126146
127147 def format_output (self , data ):
@@ -145,41 +165,48 @@ def format_output(self, data):
145165
146166 """
147167 logger .debug ("The Response of KServe v2 format %s" , data )
148- response = {}
149- if hasattr (self .context , "input_request_id" ):
150- response ["id" ] = getattr (self .context , "input_request_id" )
151- delattr (self .context , "input_request_id" )
152- else :
153- response ["id" ] = self .context .get_request_id (0 )
154- # TODO: Add parameters support
155- # if hasattr(self.context, "input_parameters"):
156- # response["parameters"] = getattr(self.context, "input_parameters")
157- # delattr(self.context, "input_parameters")
158- response ["model_name" ] = self .context .manifest .get ("model" ).get ("modelName" )
159- response ["model_version" ] = self .context .manifest .get ("model" ).get (
160- "modelVersion"
161- )
162- response ["outputs" ] = self ._batch_to_json (data )
163- return [response ]
164-
165- def _batch_to_json (self , data ):
168+ return self ._batch_to_json (data )
169+
170+ def _batch_to_json (self , batch : dict ):
166171 """
167172 Splits batch output to json objects
168173 """
169- output = []
170- input_names = getattr (self .context , "input_names" )
174+ parameters = getattr (self .context , "request_parameters" )
175+ ids = getattr (self .context , "input_request_id" )
176+ input_parameters = getattr (self .context , "input_parameters" )
177+ responses = []
178+ for index , data in enumerate (batch ):
179+ response = {}
180+ response ["id" ] = ids [index ] or self .context .get_request_id (index )
181+ if parameters and parameters [index ]:
182+ response ["parameters" ] = parameters [index ]
183+ response ["model_name" ] = self .context .manifest .get ("model" ).get ("modelName" )
184+ response ["model_version" ] = self .context .manifest .get ("model" ).get (
185+ "modelVersion"
186+ )
187+ outputs = []
188+ if isinstance (data , dict ):
189+ for key , item in data .items ():
190+ outputs .append (self ._to_json (item , key , input_parameters ))
191+ else :
192+ outputs .append (self ._to_json (data , "predictions" , input_parameters ))
193+ response ["outputs" ] = outputs
194+ responses .append (response )
171195 delattr (self .context , "input_names" )
172- for index , item in enumerate (data ):
173- output .append (self ._to_json (item , input_names [index ]))
174- return output
196+ delattr (self .context , "input_request_id" )
197+ delattr (self .context , "input_parameters" )
198+ delattr (self .context , "request_parameters" )
199+ return responses
175200
176- def _to_json (self , data , input_name ):
201+ def _to_json (self , data , output_name , parameters : Optional [ list ] = None ):
177202 """
178203 Constructs JSON object from data
179204 """
180205 output_data = {}
181206 data_ndarray = np .array (data ).flatten ()
182- output_data ["name" ] = input_name
207+ output_data ["name" ] = output_name
208+ if parameters :
209+ output_data ["parameters" ] = parameters
183210 output_data ["datatype" ] = _to_datatype (data_ndarray .dtype )
184211 output_data ["data" ] = data_ndarray .tolist ()
185212 output_data ["shape" ] = data_ndarray .flatten ().shape
0 commit comments