@@ -122,10 +122,25 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args,
122122 ctx_stream[0 ] = < TVMFFIStreamHandle> temp_ptr
123123 temp_args.append(arg)
124124 elif hasattr (arg, " __dlpack__" ):
125- arg = from_dlpack(arg)
125+ ffi_arg = from_dlpack(arg)
126126 out[i].type_index = kTVMFFITensor
127- out[i].v_ptr = (< Tensor> arg).chandle
128- temp_args.append(arg)
127+ out[i].v_ptr = (< Tensor> ffi_arg).chandle
128+ # record the stream from the source framework context when possible
129+ temp_dltensor = TVMFFITensorGetDLTensorPtr((< Tensor> ffi_arg).chandle)
130+ if (temp_dltensor.device.device_type != kDLCPU and
131+ ctx_dev_type != NULL and
132+ ctx_dev_type[0 ] == - 1 ):
133+ # __tvm_ffi_env_stream__ returns the expected stream that should be set
134+ # through TVMFFIEnvSetCurrentStream when calling a TVM FFI function
135+ if hasattr (arg, " __tvm_ffi_env_stream__" ):
136+ # Ideally projects should directly setup their stream context API
137+ # write through by also calling TVMFFIEnvSetCurrentStream
138+ # so we do not need this protocol to do exchange
139+ ctx_dev_type[0 ] = temp_dltensor.device.device_type
140+ ctx_dev_id[0 ] = temp_dltensor.device.device_id
141+ temp_ptr= arg.__tvm_ffi_env_stream__()
142+ ctx_stream[0 ] = < TVMFFIStreamHandle> temp_ptr
143+ temp_args.append(ffi_arg)
129144 elif isinstance (arg, PyNativeObject) and arg.__tvm_ffi_object__ is not None :
130145 arg = arg.__tvm_ffi_object__
131146 out[i].type_index = TVMFFIObjectGetTypeIndex((< Object> arg).chandle)
@@ -210,7 +225,7 @@ cdef inline int FuncCall3(void* chandle,
210225 with nogil:
211226 if ctx_dev_type != - 1 :
212227 # set the stream based on ctx stream
213- c_api_ret_code[0 ] = TVMFFIEnvSetStream (ctx_dev_type, ctx_dev_id, ctx_stream, & prev_stream)
228+ c_api_ret_code[0 ] = TVMFFIEnvSetCurrentStream (ctx_dev_type, ctx_dev_id, ctx_stream, & prev_stream)
214229 if c_api_ret_code[0 ] != 0 :
215230 return 0
216231 c_api_ret_code[0 ] = TVMFFIFunctionCall(
@@ -219,7 +234,7 @@ cdef inline int FuncCall3(void* chandle,
219234 # restore the original stream if it is not the same as the context stream
220235 if ctx_dev_type != - 1 and prev_stream != ctx_stream:
221236 # restore the original stream
222- c_api_ret_code[0 ] = TVMFFIEnvSetStream (ctx_dev_type, ctx_dev_id, prev_stream, NULL )
237+ c_api_ret_code[0 ] = TVMFFIEnvSetCurrentStream (ctx_dev_type, ctx_dev_id, prev_stream, NULL )
223238 if c_api_ret_code[0 ] != 0 :
224239 return 0
225240 return 0
@@ -247,13 +262,13 @@ cdef inline int FuncCall(void* chandle,
247262
248263 with nogil:
249264 if ctx_dev_type != - 1 :
250- c_api_ret_code[0 ] = TVMFFIEnvSetStream (ctx_dev_type, ctx_dev_id, ctx_stream, & prev_stream)
265+ c_api_ret_code[0 ] = TVMFFIEnvSetCurrentStream (ctx_dev_type, ctx_dev_id, ctx_stream, & prev_stream)
251266 if c_api_ret_code[0 ] != 0 :
252267 return 0
253268 c_api_ret_code[0 ] = TVMFFIFunctionCall(chandle, & packed_args[0 ], nargs, result)
254269 # restore the original stream if it is not the same as the context stream
255270 if ctx_dev_type != - 1 and prev_stream != ctx_stream:
256- c_api_ret_code[0 ] = TVMFFIEnvSetStream (ctx_dev_type, ctx_dev_id, prev_stream, NULL )
271+ c_api_ret_code[0 ] = TVMFFIEnvSetCurrentStream (ctx_dev_type, ctx_dev_id, prev_stream, NULL )
257272 if c_api_ret_code[0 ] != 0 :
258273 return 0
259274
0 commit comments