Skip to content

Commit 73b6851

Browse files
authored
[FFI][ABI] Introduce generic stream exchange protocol (#18295)
This PR adds a __tvm_ffi_env_stream__ protocol for generic tensors to exchange env stream to tvm ffi. Also renames TVMFFIEnvSetStream to TVMFFIEnvSetCurrentStream.
1 parent cf80a82 commit 73b6851

File tree

8 files changed

+134
-56
lines changed

8 files changed

+134
-56
lines changed

ffi/include/tvm/ffi/extra/c_env_api.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ typedef void* TVMFFIStreamHandle;
4949
* \note The stream is a weak reference that is cached/owned by the module.
5050
* \return 0 when success, nonzero when failure happens
5151
*/
52-
TVM_FFI_DLL int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id,
53-
TVMFFIStreamHandle stream,
54-
TVMFFIStreamHandle* opt_out_original_stream);
52+
TVM_FFI_DLL int TVMFFIEnvSetCurrentStream(int32_t device_type, int32_t device_id,
53+
TVMFFIStreamHandle stream,
54+
TVMFFIStreamHandle* opt_out_original_stream);
5555

5656
/*!
5757
* \brief FFI function to get the current stream for a device

ffi/python/tvm_ffi/cython/base.pxi

Lines changed: 55 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -24,39 +24,24 @@ from cpython cimport PyErr_CheckSignals, PyGILState_Ensure, PyGILState_Release,
2424
from cpython cimport pycapsule, PyCapsule_Destructor
2525
from cpython cimport PyErr_SetNone
2626

27-
28-
# Cython binding for TVM FFI C API
29-
cdef extern from "tvm/ffi/c_api.h":
30-
cdef enum TVMFFITypeIndex:
31-
kTVMFFIAny = -1
32-
kTVMFFINone = 0
33-
kTVMFFIInt = 1
34-
kTVMFFIBool = 2
35-
kTVMFFIFloat = 3
36-
kTVMFFIOpaquePtr = 4
37-
kTVMFFIDataType = 5
38-
kTVMFFIDevice = 6
39-
kTVMFFIDLTensorPtr = 7
40-
kTVMFFIRawStr = 8
41-
kTVMFFIByteArrayPtr = 9
42-
kTVMFFIObjectRValueRef = 10
43-
kTVMFFISmallStr = 11
44-
kTVMFFISmallBytes = 12
45-
kTVMFFIStaticObjectBegin = 64
46-
kTVMFFIObject = 64
47-
kTVMFFIStr = 65
48-
kTVMFFIBytes = 66
49-
kTVMFFIError = 67
50-
kTVMFFIFunction = 68
51-
kTVMFFIShape = 69
52-
kTVMFFITensor = 70
53-
kTVMFFIArray = 71
54-
kTVMFFIMap = 72
55-
kTVMFFIModule = 73
56-
kTVMFFIOpaquePyObject = 74
57-
58-
59-
ctypedef void* TVMFFIObjectHandle
27+
cdef extern from "dlpack/dlpack.h":
28+
cdef enum:
29+
kDLCPU = 1,
30+
kDLCUDA = 2,
31+
kDLCUDAHost = 3,
32+
kDLOpenCL = 4,
33+
kDLVulkan = 7,
34+
kDLMetal = 8,
35+
kDLVPI = 9,
36+
kDLROCM = 10,
37+
kDLROCMHost = 11,
38+
kDLExtDev = 12,
39+
kDLCUDAManaged = 13,
40+
kDLOneAPI = 14,
41+
kDLWebGPU = 15,
42+
kDLHexagon = 16,
43+
kDLMAIA = 17
44+
kDLTrn = 18
6045

6146
ctypedef struct DLDataType:
6247
uint8_t code
@@ -92,6 +77,40 @@ cdef extern from "tvm/ffi/c_api.h":
9277
void (*deleter)(DLManagedTensorVersioned* self)
9378
uint64_t flags
9479

80+
81+
# Cython binding for TVM FFI C API
82+
cdef extern from "tvm/ffi/c_api.h":
83+
cdef enum TVMFFITypeIndex:
84+
kTVMFFIAny = -1
85+
kTVMFFINone = 0
86+
kTVMFFIInt = 1
87+
kTVMFFIBool = 2
88+
kTVMFFIFloat = 3
89+
kTVMFFIOpaquePtr = 4
90+
kTVMFFIDataType = 5
91+
kTVMFFIDevice = 6
92+
kTVMFFIDLTensorPtr = 7
93+
kTVMFFIRawStr = 8
94+
kTVMFFIByteArrayPtr = 9
95+
kTVMFFIObjectRValueRef = 10
96+
kTVMFFISmallStr = 11
97+
kTVMFFISmallBytes = 12
98+
kTVMFFIStaticObjectBegin = 64
99+
kTVMFFIObject = 64
100+
kTVMFFIStr = 65
101+
kTVMFFIBytes = 66
102+
kTVMFFIError = 67
103+
kTVMFFIFunction = 68
104+
kTVMFFIShape = 69
105+
kTVMFFITensor = 70
106+
kTVMFFIArray = 71
107+
kTVMFFIMap = 72
108+
kTVMFFIModule = 73
109+
kTVMFFIOpaquePyObject = 74
110+
111+
112+
ctypedef void* TVMFFIObjectHandle
113+
95114
ctypedef struct TVMFFIObject:
96115
int32_t type_index
97116
int32_t ref_counter
@@ -219,9 +238,9 @@ cdef extern from "tvm/ffi/extra/c_env_api.h":
219238

220239
int TVMFFIEnvRegisterCAPI(const char* name, void* ptr) nogil
221240
void* TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id) nogil
222-
int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id,
223-
TVMFFIStreamHandle stream,
224-
TVMFFIStreamHandle* opt_out_original_stream) nogil
241+
int TVMFFIEnvSetCurrentStream(int32_t device_type, int32_t device_id,
242+
TVMFFIStreamHandle stream,
243+
TVMFFIStreamHandle* opt_out_original_stream) nogil
225244

226245

227246
cdef class ByteArrayArg:

ffi/python/tvm_ffi/cython/function.pxi

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,25 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args,
122122
ctx_stream[0] = <TVMFFIStreamHandle>temp_ptr
123123
temp_args.append(arg)
124124
elif hasattr(arg, "__dlpack__"):
125-
arg = from_dlpack(arg)
125+
ffi_arg = from_dlpack(arg)
126126
out[i].type_index = kTVMFFITensor
127-
out[i].v_ptr = (<Tensor>arg).chandle
128-
temp_args.append(arg)
127+
out[i].v_ptr = (<Tensor>ffi_arg).chandle
128+
# record the stream from the source framework context when possible
129+
temp_dltensor = TVMFFITensorGetDLTensorPtr((<Tensor>ffi_arg).chandle)
130+
if (temp_dltensor.device.device_type != kDLCPU and
131+
ctx_dev_type != NULL and
132+
ctx_dev_type[0] == -1):
133+
# __tvm_ffi_env_stream__ returns the expected stream that should be set
134+
# through TVMFFIEnvSetCurrentStream when calling a TVM FFI function
135+
if hasattr(arg, "__tvm_ffi_env_stream__"):
136+
# Ideally projects should directly setup their stream context API
137+
# write through by also calling TVMFFIEnvSetCurrentStream
138+
# so we do not need this protocol to do exchange
139+
ctx_dev_type[0] = temp_dltensor.device.device_type
140+
ctx_dev_id[0] = temp_dltensor.device.device_id
141+
temp_ptr= arg.__tvm_ffi_env_stream__()
142+
ctx_stream[0] = <TVMFFIStreamHandle>temp_ptr
143+
temp_args.append(ffi_arg)
129144
elif isinstance(arg, PyNativeObject) and arg.__tvm_ffi_object__ is not None:
130145
arg = arg.__tvm_ffi_object__
131146
out[i].type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
@@ -210,7 +225,7 @@ cdef inline int FuncCall3(void* chandle,
210225
with nogil:
211226
if ctx_dev_type != -1:
212227
# set the stream based on ctx stream
213-
c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, ctx_stream, &prev_stream)
228+
c_api_ret_code[0] = TVMFFIEnvSetCurrentStream(ctx_dev_type, ctx_dev_id, ctx_stream, &prev_stream)
214229
if c_api_ret_code[0] != 0:
215230
return 0
216231
c_api_ret_code[0] = TVMFFIFunctionCall(
@@ -219,7 +234,7 @@ cdef inline int FuncCall3(void* chandle,
219234
# restore the original stream if it is not the same as the context stream
220235
if ctx_dev_type != -1 and prev_stream != ctx_stream:
221236
# restore the original stream
222-
c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, prev_stream, NULL)
237+
c_api_ret_code[0] = TVMFFIEnvSetCurrentStream(ctx_dev_type, ctx_dev_id, prev_stream, NULL)
223238
if c_api_ret_code[0] != 0:
224239
return 0
225240
return 0
@@ -247,13 +262,13 @@ cdef inline int FuncCall(void* chandle,
247262

248263
with nogil:
249264
if ctx_dev_type != -1:
250-
c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, ctx_stream, &prev_stream)
265+
c_api_ret_code[0] = TVMFFIEnvSetCurrentStream(ctx_dev_type, ctx_dev_id, ctx_stream, &prev_stream)
251266
if c_api_ret_code[0] != 0:
252267
return 0
253268
c_api_ret_code[0] = TVMFFIFunctionCall(chandle, &packed_args[0], nargs, result)
254269
# restore the original stream if it is not the same as the context stream
255270
if ctx_dev_type != -1 and prev_stream != ctx_stream:
256-
c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, prev_stream, NULL)
271+
c_api_ret_code[0] = TVMFFIEnvSetCurrentStream(ctx_dev_type, ctx_dev_id, prev_stream, NULL)
257272
if c_api_ret_code[0] != 0:
258273
return 0
259274

ffi/python/tvm_ffi/cython/tensor.pxi

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,30 @@ _set_class_tensor(Tensor)
260260
_register_object_by_index(kTVMFFITensor, Tensor)
261261

262262

263+
cdef class DLTensorTestWrapper:
264+
"""Wrapper of a Tensor that exposes DLPack protocol, only for testing purpose.
265+
"""
266+
cdef Tensor tensor
267+
def __init__(self, tensor):
268+
self.tensor = tensor
269+
270+
def __tvm_ffi_env_stream__(self):
271+
cdef TVMFFIStreamHandle stream
272+
cdef long long stream_as_int
273+
cdef int c_api_ret_code
274+
with nogil:
275+
stream = TVMFFIEnvGetCurrentStream(
276+
self.tensor.cdltensor.device.device_type, self.tensor.cdltensor.device.device_id)
277+
stream_as_int = <long long>stream
278+
return stream_as_int
279+
280+
def __dlpack_device__(self):
281+
return self.tensor.__dlpack_device__()
282+
283+
def __dlpack__(self, *, **kwargs):
284+
return self.tensor.__dlpack__(**kwargs)
285+
286+
263287
cdef inline object make_ret_dltensor(TVMFFIAny result):
264288
cdef DLTensor* dltensor
265289
dltensor = <DLTensor*>result.v_ptr

ffi/scripts/benchmark_dlpack.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,11 @@
4444

4545

4646
def print_speed(name, speed):
47-
print(f"{name:<40} {speed} sec/call")
47+
print(f"{name:<60} {speed} sec/call")
4848

4949

5050
def print_error(name, error):
51-
print(f"{name:<40} {error}")
51+
print(f"{name:<60} {error}")
5252

5353

5454
def baseline_torch_add(repeat):
@@ -122,7 +122,7 @@ def tvm_ffi_nop(repeat):
122122
nop(x, y, z)
123123
start = time.time()
124124
for i in range(repeat):
125-
y = tvm_ffi.from_dlpack(x)
125+
nop(x, y, z)
126126
end = time.time()
127127
print_speed("tvm_ffi.nop", (end - start) / repeat)
128128

@@ -275,6 +275,22 @@ def tvm_ffi_nop_autodlpack_from_numpy(repeat):
275275
bench_tvm_ffi_nop_autodlpack("tvm_ffi.nop.autodlpack(numpy)", x, y, z, repeat)
276276

277277

278+
def tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, device):
279+
"""
280+
Measures overhead of running dlpack via auto convert by directly
281+
take test wrapper as inputs. This effectively measure DLPack exchange in tvm ffi.
282+
"""
283+
x = tvm_ffi.from_dlpack(torch.arange(1, device=device))
284+
y = tvm_ffi.from_dlpack(torch.arange(1, device=device))
285+
z = tvm_ffi.from_dlpack(torch.arange(1, device=device))
286+
x = tvm_ffi.core.DLTensorTestWrapper(x)
287+
y = tvm_ffi.core.DLTensorTestWrapper(y)
288+
z = tvm_ffi.core.DLTensorTestWrapper(z)
289+
bench_tvm_ffi_nop_autodlpack(
290+
f"tvm_ffi.nop.autodlpack(DLTensorTestWrapper[{device}])", x, y, z, repeat
291+
)
292+
293+
278294
def bench_to_dlpack(x, name, repeat):
279295
x.__dlpack__()
280296
start = time.time()
@@ -367,7 +383,6 @@ def main():
367383
baseline_numpy_add(repeat)
368384
baseline_torch_add(repeat)
369385
baseline_cupy_add(repeat)
370-
tvm_ffi_nop(repeat)
371386
tvm_ffi_nop_from_torch_dlpack(repeat)
372387
tvm_ffi_nop_from_numpy_dlpack(repeat)
373388
tvm_ffi_self_dlpack_nop(repeat)
@@ -377,6 +392,9 @@ def main():
377392
tvm_ffi_nop_autodlpack_from_torch(repeat, "cuda", stream=True)
378393

379394
tvm_ffi_nop_autodlpack_from_numpy(repeat)
395+
tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, "cpu")
396+
tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, "cuda")
397+
tvm_ffi_nop(repeat)
380398
print("-------------------------------")
381399
print("Benchmark x.__dlpack__ overhead")
382400
print("-------------------------------")

ffi/src/ffi/extra/stream_context.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ class StreamContext {
6666
} // namespace ffi
6767
} // namespace tvm
6868

69-
int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream,
70-
TVMFFIStreamHandle* out_original_stream) {
69+
int TVMFFIEnvSetCurrentStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream,
70+
TVMFFIStreamHandle* out_original_stream) {
7171
TVM_FFI_SAFE_CALL_BEGIN();
7272
tvm::ffi::StreamContext::ThreadLocal()->SetStream(device_type, device_id, stream,
7373
out_original_stream);

src/runtime/device_api.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,8 @@ TVMStreamHandle DeviceAPI::CreateStream(Device dev) { return nullptr; }
165165
void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {}
166166

167167
void DeviceAPI::SetStream(Device dev, TVMStreamHandle stream) {
168-
TVM_FFI_CHECK_SAFE_CALL(TVMFFIEnvSetStream(dev.device_type, dev.device_id, stream, nullptr));
168+
TVM_FFI_CHECK_SAFE_CALL(
169+
TVMFFIEnvSetCurrentStream(dev.device_type, dev.device_id, stream, nullptr));
169170
}
170171

171172
TVMStreamHandle DeviceAPI::GetCurrentStream(Device dev) {

src/runtime/vm/cuda/cuda_graph_builtin.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,13 +118,14 @@ class CUDACaptureStream {
118118
explicit CUDACaptureStream(cudaGraph_t* graph) : output_graph_(graph) {
119119
CUDA_CALL(cudaGetDevice(&device_id_));
120120
TVM_FFI_CHECK_SAFE_CALL(
121-
TVMFFIEnvSetStream(kDLCUDA, device_id_, capture_stream_,
122-
reinterpret_cast<TVMFFIStreamHandle*>(&prev_default_stream_)));
121+
TVMFFIEnvSetCurrentStream(kDLCUDA, device_id_, capture_stream_,
122+
reinterpret_cast<TVMFFIStreamHandle*>(&prev_default_stream_)));
123123
CUDA_CALL(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal));
124124
}
125125
~CUDACaptureStream() noexcept(false) {
126126
cudaStreamEndCapture(capture_stream_, output_graph_);
127-
TVM_FFI_CHECK_SAFE_CALL(TVMFFIEnvSetStream(kDLCUDA, device_id_, prev_default_stream_, nullptr));
127+
TVM_FFI_CHECK_SAFE_CALL(
128+
TVMFFIEnvSetCurrentStream(kDLCUDA, device_id_, prev_default_stream_, nullptr));
128129
}
129130

130131
private:

0 commit comments

Comments
 (0)