Commit 8edd90b

[FFI][ABI][REFACTOR] Enhance DLPack Exchange Speed and Behavior
This PR enhances DLPack exchange by introducing DLPackPyObjectExporter, DLPackPyObjectImporter, and DLPackTensorAllocator. These three function pointers speed up DLPack import/export and streamline the rare (but occasionally useful) case of allocation inside the FFI. They significantly speed up autodlpack import. They also enable querying the allocator from the environment and returning an ffi::Tensor back to the caller's environment (experimental): when a function takes a torch.Tensor as an argument, returned Tensor values will be converted to torch.Tensor. Also renames SetCurrentStream => SetStream to align with the naming style of the CUDA API.
1 parent 85dc1d7 commit 8edd90b

52 files changed: +1730 −251 lines

ffi/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -73,7 +73,7 @@ set(tvm_ffi_extra_objs_sources
   "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module.cc"
   "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module_system_lib.cc"
   "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module_dynamic_lib.cc"
-  "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/stream_context.cc"
+  "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/env_context.cc"
   "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/env_c_api.cc"
   "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/testing.cc"
 )
@@ -249,6 +249,7 @@ endif()
 
 install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/tvm/ffi/ DESTINATION include/tvm/ffi/)
 install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/dlpack/include/ DESTINATION include/)
+install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/tvm_ffi_python_helpers.h DESTINATION include/)
 install(TARGETS tvm_ffi_shared DESTINATION lib)
 # ship additional dSYM files for debugging symbols on if available
 if (APPLE)

ffi/docs/get_started/quick_start.md

Lines changed: 2 additions & 2 deletions
@@ -125,7 +125,7 @@ void AddOneCUDA(DLTensor* x, DLTensor* y) {
 
   // Get current CUDA stream from environment
   cudaStream_t stream = static_cast<cudaStream_t>(
-      TVMFFIEnvGetCurrentStream(x->device.device_type, x->device.device_id));
+      TVMFFIEnvGetStream(x->device.device_type, x->device.device_id));
 
   // Launch kernel
   AddOneKernel<<<nblock, nthread_per_block, 0, stream>>>(
@@ -136,7 +136,7 @@ TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cuda, tvm_ffi_example::AddOneCUDA);
 ```
 
 **Key Points:**
-- We use `TVMFFIEnvGetCurrentStream` to obtain the current stream from the environement
+- We use `TVMFFIEnvGetStream` to obtain the current stream from the environement
 - When invoking ffi Function from python end with PyTorch tensor as argument,
   the stream will be populated with torch's current stream.

ffi/examples/inline_module/main.py

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ def main():
     // it will be set to torch.cuda.current_stream() when calling the function
     // with torch.Tensors
     cudaStream_t stream = static_cast<cudaStream_t>(
-        TVMFFIEnvGetCurrentStream(x->device.device_type, x->device.device_id));
+        TVMFFIEnvGetStream(x->device.device_type, x->device.device_id));
     // launch the kernel
     AddOneKernel<<<nblock, nthread_per_block, 0, stream>>>(static_cast<float*>(x->data),
                                                            static_cast<float*>(y->data), n);

ffi/examples/quick_start/run_example.py

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ def run_add_one_cuda():
     with torch.cuda.stream(stream):
         # tvm-ffi automatically handles DLPack compatible tensors
         # it also handles interactions with torch runtime
-        # torch.cuda.current_stream() will be set and available via TVMFFIEnvGetCurrentStream
+        # torch.cuda.current_stream() will be set and available via TVMFFIEnvGetStream
         # when calling the function
         mod.add_one_cuda(x, y)
     stream.synchronize()

ffi/examples/quick_start/src/add_one_cuda.cu

Lines changed: 2 additions & 2 deletions
@@ -46,8 +46,8 @@ void AddOneCUDA(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
   // Obtain the current stream from the environment
   // it will be set to torch.cuda.current_stream() when calling the function
   // with torch.Tensors
-  cudaStream_t stream = static_cast<cudaStream_t>(
-      TVMFFIEnvGetCurrentStream(x->device.device_type, x->device.device_id));
+  cudaStream_t stream =
+      static_cast<cudaStream_t>(TVMFFIEnvGetStream(x->device.device_type, x->device.device_id));
   // launch the kernel
   AddOneKernel<<<nblock, nthread_per_block, 0, stream>>>(static_cast<float*>(x->data),
                                                          static_cast<float*>(y->data), n);

ffi/include/tvm/ffi/c_api.h

Lines changed: 15 additions & 0 deletions
@@ -27,6 +27,21 @@
 #include <dlpack/dlpack.h>
 #include <stdint.h>
 
+/*
+ * \brief C-style Allocator that allocates memory for a DLPack tensor.
+ * \param prototype The prototype DLTensor to offer details about device and shape.
+ * \param out The output DLManagedTensorVersioned.
+ * \param error_ctx The context to set the error.
+ * \param SetError The function to set the error.
+ * \return 0 on success, -1 on failure.
+ *   call SetError(error_ctx, kind, message) to set the error kind and message.
+ * \note Error propagation via SetError.
+ */
+typedef int (*DLPackTensorAllocator)(  //
+    DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx,  //
+    void (*SetError)(void* error_ctx, const char* kind, const char* message)  //
+);
+
 // Macros to do weak linking
 #ifdef _MSC_VER
 #define TVM_FFI_WEAK __declspec(selectany)
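
For illustration, a minimal host-memory allocator that satisfies this signature could look like the sketch below. The function name `HostDLPackAlloc` and the malloc-based strategy are assumptions made for this example, not part of the commit; only `DLPackTensorAllocator`, `DLTensor`, and `DLManagedTensorVersioned` come from the header above and dlpack.

```cpp
#include <dlpack/dlpack.h>

#include <cstdlib>
#include <cstring>

// Hypothetical allocator matching DLPackTensorAllocator: on success it fills *out
// and returns 0; on failure it reports the error through SetError and returns -1.
static int HostDLPackAlloc(DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx,
                           void (*SetError)(void* error_ctx, const char* kind, const char* message)) {
  int64_t numel = 1;
  for (int i = 0; i < prototype->ndim; ++i) numel *= prototype->shape[i];
  size_t nbytes =
      static_cast<size_t>(numel) * ((prototype->dtype.bits * prototype->dtype.lanes + 7) / 8);

  void* data = std::malloc(nbytes);
  if (data == nullptr) {
    SetError(error_ctx, "RuntimeError", "HostDLPackAlloc: out of memory");
    return -1;
  }
  // Own a copy of the shape so the returned tensor does not alias the prototype.
  int64_t* shape = static_cast<int64_t*>(std::malloc(sizeof(int64_t) * prototype->ndim));
  std::memcpy(shape, prototype->shape, sizeof(int64_t) * prototype->ndim);

  auto* managed = new DLManagedTensorVersioned();
  managed->version.major = DLPACK_MAJOR_VERSION;
  managed->version.minor = DLPACK_MINOR_VERSION;
  managed->flags = 0;
  managed->manager_ctx = nullptr;
  managed->dl_tensor = *prototype;  // copies device, dtype, ndim, strides, byte_offset
  managed->dl_tensor.data = data;
  managed->dl_tensor.shape = shape;
  managed->deleter = [](DLManagedTensorVersioned* self) {
    std::free(self->dl_tensor.data);
    std::free(self->dl_tensor.shape);
    delete self;
  };
  *out = managed;
  return 0;
}
```

An environment would typically register such an allocator via `TVMFFIEnvSetTensorAllocator` (see `ffi/include/tvm/ffi/extra/c_env_api.h` below) so that `Tensor::FromDLPackAlloc` can pick it up.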

ffi/include/tvm/ffi/container/tensor.h

Lines changed: 54 additions & 1 deletion
@@ -341,7 +341,60 @@ class Tensor : public ObjectRef {
     return Tensor(make_object<details::TensorObjFromNDAlloc<TNDAlloc>>(
         alloc, shape, dtype, device, std::forward<ExtraArgs>(extra_args)...));
   }
-
+  /*!
+   * \brief Create a Tensor from a DLPackTensorAllocator
+   *
+   * This function can be used together with TVMFFIEnvSetTensorAllocator
+   * in the extra/c_env_api.h to create Tensor from the thread-local
+   * environment allocator.
+   *
+   * \code
+   *
+   * ffi::Tensor tensor = ffi::Tensor::FromDLPackAlloc(
+   *   TVMFFIEnvGetTensorAllocator(), shape, dtype, device
+   * );
+   * \endcode
+   *
+   * \param allocator The DLPack allocator.
+   * \param shape The shape of the Tensor.
+   * \param dtype The data type of the Tensor.
+   * \param device The device of the Tensor.
+   * \return The created Tensor.
+   */
+  static Tensor FromDLPackAlloc(DLPackTensorAllocator allocator, ffi::Shape shape, DLDataType dtype,
+                                DLDevice device) {
+    if (allocator == nullptr) {
+      TVM_FFI_THROW(RuntimeError)
+          << "FromDLPackAlloc: allocator is nullptr, "
+          << "likely because TVMFFIEnvSetTensorAllocator has not been called.";
+    }
+    DLTensor prototype;
+    prototype.device = device;
+    prototype.dtype = dtype;
+    prototype.shape = const_cast<int64_t*>(shape.data());
+    prototype.ndim = static_cast<int>(shape.size());
+    prototype.strides = nullptr;
+    prototype.byte_offset = 0;
+    prototype.data = nullptr;
+    DLManagedTensorVersioned* tensor = nullptr;
+    // error context to be used to propagate error
+    struct ErrorContext {
+      std::string kind;
+      std::string message;
+      static void SetError(void* error_ctx, const char* kind, const char* message) {
+        ErrorContext* error_context = static_cast<ErrorContext*>(error_ctx);
+        error_context->kind = kind;
+        error_context->message = message;
+      }
+    };
+    ErrorContext error_context;
+    int ret = (*allocator)(&prototype, &tensor, &error_context, ErrorContext::SetError);
+    if (ret != 0) {
+      throw ffi::Error(error_context.kind, error_context.message,
+                       TVMFFITraceback(__FILE__, __LINE__, __func__, 0));
+    }
+    return Tensor(make_object<details::TensorObjFromDLPack<DLManagedTensorVersioned>>(tensor));
+  }
   /*!
    * \brief Create a Tensor from a DLPack managed tensor, pre v1.0 API.
    * \param tensor The input DLPack managed tensor.

ffi/include/tvm/ffi/extra/c_env_api.h

Lines changed: 26 additions & 5 deletions
@@ -46,12 +46,11 @@ typedef void* TVMFFIStreamHandle;
  * \param device_id The id of the device.
  * \param stream The stream to set.
  * \param opt_out_original_stream Output original stream if the address is not nullptr.
- * \note The stream is a weak reference that is cached/owned by the module.
  * \return 0 when success, nonzero when failure happens
  */
-TVM_FFI_DLL int TVMFFIEnvSetCurrentStream(int32_t device_type, int32_t device_id,
-                                          TVMFFIStreamHandle stream,
-                                          TVMFFIStreamHandle* opt_out_original_stream);
+TVM_FFI_DLL int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id,
+                                   TVMFFIStreamHandle stream,
+                                   TVMFFIStreamHandle* opt_out_original_stream);
 
 /*!
  * \brief FFI function to get the current stream for a device
@@ -60,7 +59,29 @@ TVM_FFI_DLL int TVMFFIEnvSetCurrentStream(int32_t device_type, int32_t device_id
  * \param device_id The id of the device.
  * \return The current stream of the device.
  */
-TVM_FFI_DLL TVMFFIStreamHandle TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id);
+TVM_FFI_DLL TVMFFIStreamHandle TVMFFIEnvGetStream(int32_t device_type, int32_t device_id);
+
+/*!
+ * \brief FFI function to set the current DLPack allocator in thread-local(TLS) context
+ *
+ * \param allocator The allocator to set.
+ * \param write_to_global_context Whether to also set the allocator to the global context.
+ * \param opt_out_original_allocator Output original TLS allocator if the address is not nullptr.
+ * \return 0 when success, nonzero when failure happens
+ */
+TVM_FFI_DLL int TVMFFIEnvSetTensorAllocator(DLPackTensorAllocator allocator,
+                                            int write_to_global_context,
+                                            DLPackTensorAllocator* opt_out_original_allocator);
+
+/*!
+ * \brief FFI function get the current DLPack allocator stored in context.
+ *
+ * This function first queries the global context, and if not found,
+ * queries the thread-local context.
+ *
+ * \return The current DLPack allocator.
+ */
+TVM_FFI_DLL DLPackTensorAllocator TVMFFIEnvGetTensorAllocator();
 
 /*!
  * \brief Check if there are any signals raised in the surrounding env.
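
To show how the renamed and newly added context APIs fit together, here is a hedged host-side sketch: it registers an allocator, swaps in a caller-provided stream around some FFI calls, then restores the previous state. `MyDLPackAlloc`, `my_stream`, and `CallFFIFunctions` are placeholders invented for this example; the signatures of `TVMFFIEnvSetTensorAllocator`, `TVMFFIEnvSetStream`, and `TVMFFIEnvGetStream` follow the header above.

```cpp
#include <tvm/ffi/extra/c_env_api.h>

// Placeholder declarations for this sketch.
int MyDLPackAlloc(DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx,
                  void (*SetError)(void* error_ctx, const char* kind, const char* message));
void CallFFIFunctions();

void RunWithEnv(TVMFFIStreamHandle my_stream, int device_id) {
  // Register the DLPack allocator in the thread-local context only;
  // pass 1 as the second argument to also publish it to the global context.
  DLPackTensorAllocator prev_alloc = nullptr;
  TVMFFIEnvSetTensorAllocator(MyDLPackAlloc, /*write_to_global_context=*/0, &prev_alloc);

  // Swap in the caller's stream for this CUDA device, remembering the old one.
  TVMFFIStreamHandle prev_stream = nullptr;
  TVMFFIEnvSetStream(kDLCUDA, device_id, my_stream, &prev_stream);

  // FFI functions called here can read the stream via TVMFFIEnvGetStream and
  // allocate result tensors via TVMFFIEnvGetTensorAllocator.
  CallFFIFunctions();

  // Restore the previous stream and allocator.
  TVMFFIEnvSetStream(kDLCUDA, device_id, prev_stream, nullptr);
  TVMFFIEnvSetTensorAllocator(prev_alloc, /*write_to_global_context=*/0, nullptr);
}
```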

ffi/licenses/LICENSE.pytorch.txt

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+From PyTorch:
+
+Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
+Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
+Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+Copyright (c) 2011-2013 NYU (Clement Farabet)
+Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
+Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
+Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+
+From Caffe2:
+
+Copyright (c) 2016-present, Facebook Inc. All rights reserved.
+
+All contributions by Facebook:
+Copyright (c) 2016 Facebook Inc.
+
+All contributions by Google:
+Copyright (c) 2015 Google Inc.
+All rights reserved.
+
+All contributions by Yangqing Jia:
+Copyright (c) 2015 Yangqing Jia
+All rights reserved.
+
+All contributions by Kakao Brain:
+Copyright 2019-2020 Kakao Brain
+
+All contributions by Cruise LLC:
+Copyright (c) 2022 Cruise LLC.
+All rights reserved.
+
+All contributions by Tri Dao:
+Copyright (c) 2024 Tri Dao.
+All rights reserved.
+
+All contributions by Arm:
+Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates
+
+All contributions from Caffe:
+Copyright(c) 2013, 2014, 2015, the respective contributors
+All rights reserved.
+
+All other contributions:
+Copyright(c) 2015, 2016 the respective contributors
+All rights reserved.
+
+Caffe2 uses a copyright model similar to Caffe: each contributor holds
+copyright over their contributions to Caffe2. The project versioning records
+all such contribution and copyright details. If a contributor wants to further
+mark their specific copyright on a particular contribution, they should
+indicate their copyright solely in the commit message of the change when it is
+committed.
+
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
+   and IDIAP Research Institute nor the names of its contributors may be
+   used to endorse or promote products derived from this software without
+   specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+POSSIBILITY OF SUCH DAMAGE.
