12 changes: 8 additions & 4 deletions ffi/include/tvm/ffi/container/ndarray.h
@@ -151,6 +151,7 @@ class NDArrayObj : public Object, public DLTensor {
protected:
// backs up the shape of the NDArray
Optional<Shape> shape_data_;
// backs up the strides of the NDArray
Optional<Shape> stride_data_;

static void DLManagedTensorDeleter(DLManagedTensor* tensor) {
NDArrayObj* obj = static_cast<NDArrayObj*>(tensor->manager_ctx);
@@ -184,9 +185,11 @@ class NDArrayObjFromNDAlloc : public NDArrayObj {
this->ndim = static_cast<int>(shape.size());
this->dtype = dtype;
this->shape = const_cast<int64_t*>(shape.data());
this->strides = nullptr;
Shape strides = Shape(details::MakeStridesFromShape(this->ndim, this->shape));
this->strides = const_cast<int64_t*>(strides.data());
this->byte_offset = 0;
this->shape_data_ = std::move(shape);
this->stride_data_ = std::move(strides);
alloc_.AllocData(static_cast<DLTensor*>(this), std::forward<ExtraArgs>(extra_args)...);
}

@@ -202,9 +205,10 @@ class NDArrayObjFromDLPack : public NDArrayObj {
public:
explicit NDArrayObjFromDLPack(TDLPackManagedTensor* tensor) : tensor_(tensor) {
*static_cast<DLTensor*>(this) = tensor_->dl_tensor;
// set strides to nullptr if the tensor is contiguous.
if (IsContiguous(tensor->dl_tensor)) {
[Review comment from a Member]: also need to check tensor->dl_tensor->strides == nullptr and act accordingly if needed

this->strides = nullptr;
if (tensor_->dl_tensor.strides == nullptr) {
Shape strides = Shape(details::MakeStridesFromShape(ndim, shape));
this->strides = const_cast<int64_t*>(strides.data());
this->stride_data_ = std::move(strides);
}
}

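With this change an NDArray always carries explicit strides: allocation computes them from the shape, and DLPack import materializes them when the producer passed strides == nullptr (which DLPack permits for compact row-major tensors). A minimal standalone sketch of that materialization rule, under the assumption that it mirrors the constructor logic above (names hypothetical, not TVM's actual code):

#include <algorithm>
#include <cstdint>
#include <vector>

#include <dlpack/dlpack.h>

// Return explicit strides for a DLTensor: copy them if present,
// otherwise derive compact row-major strides from the shape.
std::vector<int64_t> MaterializeStrides(const DLTensor& t) {
  std::vector<int64_t> strides(t.ndim);
  if (t.strides != nullptr) {
    std::copy(t.strides, t.strides + t.ndim, strides.begin());
  } else {
    int64_t stride = 1;
    for (int i = t.ndim - 1; i >= 0; --i) {
      strides[i] = stride;
      stride *= t.shape[i];
    }
  }
  return strides;
}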
11 changes: 11 additions & 0 deletions ffi/include/tvm/ffi/container/shape.h
@@ -91,6 +91,17 @@ TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeInplaceShape(IterType begin, IterType end
return p;
}

// Computes compact row-major strides for the given shape.
TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeStridesFromShape(int64_t ndim, int64_t* shape) {
int64_t* strides_data;
ObjectPtr<ShapeObj> strides = details::MakeEmptyShape(ndim, &strides_data);
int64_t stride = 1;
for (int i = ndim - 1; i >= 0; --i) {
strides_data[i] = stride;
stride *= shape[i];
}
return strides;
}

} // namespace details

/*!
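MakeStridesFromShape produces compact row-major strides: the innermost dimension gets stride 1 and each outer stride is the product of the inner extents. A quick worked check of the same loop (illustrative only, outside the Shape machinery):

#include <cassert>
#include <cstdint>

int main() {
  int64_t shape[3] = {2, 3, 4};
  int64_t strides[3];
  int64_t stride = 1;
  for (int i = 2; i >= 0; --i) {  // same traversal as MakeStridesFromShape
    strides[i] = stride;
    stride *= shape[i];
  }
  // shape {2, 3, 4} -> strides {12, 4, 1}
  assert(strides[0] == 12 && strides[1] == 4 && strides[2] == 1);
  return 0;
}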
6 changes: 4 additions & 2 deletions ffi/tests/cpp/test_ndarray.cc
@@ -69,7 +69,9 @@ TEST(NDArray, DLPack) {
EXPECT_EQ(dlpack->dl_tensor.device.device_type, kDLCPU);
EXPECT_EQ(dlpack->dl_tensor.device.device_id, 0);
EXPECT_EQ(dlpack->dl_tensor.byte_offset, 0);
EXPECT_EQ(dlpack->dl_tensor.strides, nullptr);
EXPECT_EQ(dlpack->dl_tensor.strides[0], 6);
EXPECT_EQ(dlpack->dl_tensor.strides[1], 3);
EXPECT_EQ(dlpack->dl_tensor.strides[2], 1);
EXPECT_EQ(nd.use_count(), 2);
{
NDArray nd2 = NDArray::FromDLPack(dlpack);
@@ -96,7 +98,7 @@ TEST(NDArray, DLPackVersioned) {
EXPECT_EQ(dlpack->dl_tensor.device.device_type, kDLCPU);
EXPECT_EQ(dlpack->dl_tensor.device.device_id, 0);
EXPECT_EQ(dlpack->dl_tensor.byte_offset, 0);
EXPECT_EQ(dlpack->dl_tensor.strides, nullptr);
EXPECT_EQ(dlpack->dl_tensor.strides[0], 1);

EXPECT_EQ(nd.use_count(), 2);
{
2 changes: 1 addition & 1 deletion include/tvm/runtime/ndarray.h
@@ -239,7 +239,7 @@ inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor) {
strm->Write(data_byte_size);

if (DMLC_IO_NO_ENDIAN_SWAP && tensor->device.device_type == kDLCPU &&
tensor->strides == nullptr && tensor->byte_offset == 0) {
ffi::IsContiguous(*tensor) && tensor->byte_offset == 0) {
// quick path
strm->Write(tensor->data, data_byte_size);
} else {
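Because strides may now be non-null even for dense tensors, contiguity checks can no longer compare against nullptr and go through ffi::IsContiguous instead, here and in all the files below. A sketch of the semantics this PR assumes for that predicate (strides absent, or equal to the compact row-major strides implied by the shape):

#include <cstdint>

#include <dlpack/dlpack.h>

// Assumed behavior of ffi::IsContiguous, sketched for illustration.
bool IsContiguousSketch(const DLTensor& t) {
  if (t.strides == nullptr) return true;  // DLPack: null means compact row-major
  int64_t expected = 1;
  for (int i = t.ndim - 1; i >= 0; --i) {
    // extent-1 dimensions place no constraint on their stride
    if (t.shape[i] != 1 && t.strides[i] != expected) return false;
    expected *= t.shape[i];
  }
  return true;
}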
2 changes: 1 addition & 1 deletion src/relax/transform/fold_constant.cc
@@ -270,7 +270,7 @@ class ConstantFolder : public ExprMutator {
Constant constant = Downcast<Constant>(arg);
runtime::NDArray ndarray = constant->data;
ICHECK_EQ(ndarray->device.device_type, kDLCPU);
ICHECK(ndarray->strides == nullptr);
ICHECK(ffi::IsContiguous(*ndarray.get()));
ICHECK_EQ(ndarray->byte_offset, 0);
ICHECK_EQ(ndarray->ndim, 1);
const int64_t* data = static_cast<const int64_t*>(ndarray->data);
2 changes: 1 addition & 1 deletion src/runtime/contrib/coreml/coreml_runtime.mm
@@ -60,7 +60,7 @@

MLMultiArray* dest = [[MLMultiArray alloc] initWithShape:shape dataType:dataType error:nil];

ICHECK(data_in->strides == NULL);
ICHECK(ffi::IsContiguous(*data_in));
memcpy(dest.dataPointer, data_in->data, size);

NSString* nsKey = [NSString stringWithUTF8String:key.c_str()];
2 changes: 1 addition & 1 deletion src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -821,7 +821,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
TensorRequisite res;
if (const_dl_tensor) {
ICHECK(const_dl_tensor->data);
ICHECK(const_dl_tensor->strides == nullptr);
ICHECK(ffi::IsContiguous(*const_dl_tensor));
auto mem = dnnl::memory(desc, engine_, const_dl_tensor->data);
res = TensorRequisite::AsIs(mem, eid);
} else {
6 changes: 3 additions & 3 deletions src/runtime/contrib/mps/conv.mm
@@ -91,9 +91,9 @@
ICHECK_EQ(data->ndim, 4);
ICHECK_EQ(weight->ndim, 4);
ICHECK_EQ(output->ndim, 4);
ICHECK(output->strides == nullptr);
ICHECK(weight->strides == nullptr);
ICHECK(data->strides == nullptr);
ICHECK(ffi::IsContiguous(*output));
ICHECK(ffi::IsContiguous(*weight));
ICHECK(ffi::IsContiguous(*data));

ICHECK_EQ(data->shape[0], 1);
ICHECK_EQ(output->shape[0], 1);
6 changes: 3 additions & 3 deletions src/runtime/contrib/mps/gemm.mm
@@ -37,9 +37,9 @@
ICHECK_EQ(A->ndim, 2);
ICHECK_EQ(B->ndim, 2);
ICHECK_EQ(C->ndim, 2);
ICHECK(C->strides == nullptr);
ICHECK(B->strides == nullptr);
ICHECK(A->strides == nullptr);
ICHECK(ffi::IsContiguous(*C));
ICHECK(ffi::IsContiguous(*B));
ICHECK(ffi::IsContiguous(*A));
ICHECK(TypeMatch(A->dtype, kDLFloat, 32));
ICHECK(TypeMatch(B->dtype, kDLFloat, 32));
ICHECK(TypeMatch(C->dtype, kDLFloat, 32));
4 changes: 2 additions & 2 deletions src/runtime/contrib/random/mt_random_engine.cc
@@ -75,7 +75,7 @@ class RandomEngine {
*/
void SampleUniform(DLTensor* data, float low, float high) {
ICHECK_GT(high, low) << "high must be bigger than low";
ICHECK(data->strides == nullptr);
ICHECK(ffi::IsContiguous(*data));

DLDataType dtype = data->dtype;
int64_t size = 1;
@@ -99,7 +99,7 @@
*/
void SampleNormal(DLTensor* data, float loc, float scale) {
ICHECK_GT(scale, 0) << "standard deviation must be positive";
ICHECK(data->strides == nullptr);
ICHECK(ffi::IsContiguous(*data));

DLDataType dtype = data->dtype;
int64_t size = 1;
2 changes: 1 addition & 1 deletion src/runtime/contrib/random/random.cc
@@ -80,7 +80,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
int64_t high = args[1].cast<int64_t>();
auto out = args[2].cast<DLTensor*>();
ICHECK_GT(high, low) << "high must be bigger than low";
ICHECK(out->strides == nullptr);
ICHECK(ffi::IsContiguous(*out));

DLDataType dtype = out->dtype;
int64_t size = 1;
6 changes: 3 additions & 3 deletions src/runtime/contrib/rocblas/rocblas.cc
@@ -81,9 +81,9 @@ TVM_FFI_STATIC_INIT_BLOCK({
ICHECK_EQ(A->ndim, 2);
ICHECK_EQ(B->ndim, 2);
ICHECK_EQ(C->ndim, 2);
ICHECK(C->strides == nullptr);
ICHECK(B->strides == nullptr);
ICHECK(A->strides == nullptr);
ICHECK(ffi::IsContiguous(*C));
ICHECK(ffi::IsContiguous(*B));
ICHECK(ffi::IsContiguous(*A));
ICHECK(TypeMatch(A->dtype, kDLFloat, 32));
ICHECK(TypeMatch(B->dtype, kDLFloat, 32));
ICHECK(TypeMatch(C->dtype, kDLFloat, 32));
2 changes: 1 addition & 1 deletion src/runtime/contrib/tflite/tflite_runtime.cc
@@ -118,7 +118,7 @@ void TFLiteRuntime::SetInput(int index, DLTensor* data_in) {
TVM_DTYPE_DISPATCH(dtype, DType, {
DType* dest = interpreter_->typed_input_tensor<DType>(index);
DType* src = static_cast<DType*>(data_in->data);
ICHECK(data_in->strides == NULL);
ICHECK(ffi::IsContiguous(*data_in));
int64_t size = 1;
for (int64_t i = 0; i < data_in->ndim; ++i) {
size *= data_in->shape[i];
4 changes: 3 additions & 1 deletion src/runtime/minrpc/rpc_reference.h
@@ -24,6 +24,8 @@
#ifndef TVM_RUNTIME_MINRPC_RPC_REFERENCE_H_
#define TVM_RUNTIME_MINRPC_RPC_REFERENCE_H_

#include <tvm/ffi/container/ndarray.h>

namespace tvm {
namespace ffi {
// Forward declare TVM Object to use `Object*` in RPC protocol.
@@ -255,7 +257,7 @@ struct RPCReference {
channel->Write(arr->ndim);
channel->Write(arr->dtype);
channel->WriteArray(arr->shape, arr->ndim);
if (arr->strides != nullptr) {
if (!ffi::IsContiguous(*arr)) {
channel->ThrowError(RPCServerStatus::kInvalidDLTensorFieldStride);
}
channel->Write(arr->byte_offset);
4 changes: 3 additions & 1 deletion src/runtime/vm/rnn_state.cc
@@ -396,6 +396,7 @@ class RNNStateImpObj : public RNNStateObj {
_state.byte_offset = elem_offset * state->dtype.bits / 8;
_state.ndim = state->ndim - 2;
_state.shape = const_cast<int64_t*>(_state.shape + 2);
_state.strides = const_cast<int64_t*>(_state.strides + 2);
return _state;
}

@@ -411,6 +412,7 @@ class RNNStateImpObj : public RNNStateObj {
_state.byte_offset = elem_offset * state->dtype.bits / 8;
_state.ndim = state->ndim - 1;
_state.shape = const_cast<int64_t*>(_state.shape + 1);
_state.strides = const_cast<int64_t*>(_state.strides + 1);
return _state;
}

@@ -428,7 +430,7 @@
copy_src.ndim = 1;
copy_src.dtype = array->dtype;
copy_src.shape = array->shape;
copy_src.strides = nullptr;
copy_src.strides = array->strides;
copy_src.byte_offset = 0;
NDArray::CopyFromTo(&copy_src, &copy_dst);
};
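The rnn_state.cc changes keep strides in lockstep with shape when carving a lower-rank view out of a state tensor: now that strides are always materialized, advancing only the shape pointer would leave strides[i] describing the wrong dimension. A standalone sketch of the viewing pattern (hypothetical helper, not the code above):

#include <cstdint>

#include <dlpack/dlpack.h>

// View `src` with its first k dimensions dropped: shape and strides
// advance together so strides[i] still describes shape[i]; byte_offset
// selects which slice the view starts at.
DLTensor DropLeadingDims(const DLTensor& src, int k, uint64_t byte_offset) {
  DLTensor view = src;
  view.ndim = src.ndim - k;
  view.shape = src.shape + k;
  view.strides = (src.strides != nullptr) ? src.strides + k : nullptr;
  view.byte_offset = byte_offset;
  return view;
}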