Skip to content

Commit 2028973

Browse files
committed
[FFI][REFACTOR] Hide StringObj/BytesObj into details
This PR hides StringObj/BytesObj into details and bring implementations to directly focus on the String/Bytes. This change will prepare us for future changes such as SmallStr support. Also moves more ObjectRef into Any in RPC.
1 parent dafdafe commit 2028973

File tree

37 files changed

+198
-182
lines changed

37 files changed

+198
-182
lines changed

ffi/include/tvm/ffi/any.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -546,8 +546,8 @@ struct AnyHash {
546546
uint64_t val_hash = [&]() -> uint64_t {
547547
if (src.data_.type_index == TypeIndex::kTVMFFIStr ||
548548
src.data_.type_index == TypeIndex::kTVMFFIBytes) {
549-
const BytesObjBase* src_str =
550-
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(src);
549+
const details::BytesObjBase* src_str =
550+
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(src);
551551
return details::StableHashBytes(src_str->data, src_str->size);
552552
} else {
553553
return src.data_.v_uint64;
@@ -572,10 +572,10 @@ struct AnyEqual {
572572
// specialy handle string hash
573573
if (lhs.data_.type_index == TypeIndex::kTVMFFIStr ||
574574
lhs.data_.type_index == TypeIndex::kTVMFFIBytes) {
575-
const BytesObjBase* lhs_str =
576-
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(lhs);
577-
const BytesObjBase* rhs_str =
578-
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(rhs);
575+
const details::BytesObjBase* lhs_str =
576+
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs);
577+
const details::BytesObjBase* rhs_str =
578+
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs);
579579
return Bytes::memequal(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size);
580580
}
581581
return false;

ffi/include/tvm/ffi/string.h

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646

4747
namespace tvm {
4848
namespace ffi {
49-
49+
namespace details {
5050
/*! \brief Base class for bytes and string. */
5151
class BytesObjBase : public Object, public TVMFFIByteArray {};
5252

@@ -73,8 +73,6 @@ class StringObj : public BytesObjBase {
7373
TVM_FFI_DECLARE_STATIC_OBJECT_INFO(StringObj, Object);
7474
};
7575

76-
namespace details {
77-
7876
// String moved from std::string
7977
// without having to trigger a copy
8078
template <typename Base>
@@ -115,21 +113,21 @@ class Bytes : public ObjectRef {
115113
* \param other a char array.
116114
*/
117115
Bytes(const char* data, size_t size) // NOLINT(*)
118-
: ObjectRef(details::MakeInplaceBytes<BytesObj>(data, size)) {}
116+
: ObjectRef(details::MakeInplaceBytes<details::BytesObj>(data, size)) {}
119117
/*!
120118
* \brief constructor from char [N]
121119
*
122120
* \param other a char array.
123121
*/
124122
Bytes(TVMFFIByteArray bytes) // NOLINT(*)
125-
: ObjectRef(details::MakeInplaceBytes<BytesObj>(bytes.data, bytes.size)) {}
123+
: ObjectRef(details::MakeInplaceBytes<details::BytesObj>(bytes.data, bytes.size)) {}
126124
/*!
127125
* \brief constructor from char [N]
128126
*
129127
* \param other a char array.
130128
*/
131129
Bytes(std::string other) // NOLINT(*)
132-
: ObjectRef(make_object<details::BytesObjStdImpl<BytesObj>>(std::move(other))) {}
130+
: ObjectRef(make_object<details::BytesObjStdImpl<details::BytesObj>>(std::move(other))) {}
133131
/*!
134132
* \brief Swap this String with another string
135133
* \param other The other string
@@ -163,7 +161,7 @@ class Bytes : public ObjectRef {
163161
*/
164162
operator std::string() const { return std::string{get()->data, size()}; }
165163

166-
TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bytes, ObjectRef, BytesObj);
164+
TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bytes, ObjectRef, details::BytesObj);
167165

168166
/*!
169167
* \brief Compare two char sequence
@@ -245,7 +243,7 @@ class String : public ObjectRef {
245243
*/
246244
template <size_t N>
247245
String(const char other[N]) // NOLINT(*)
248-
: ObjectRef(details::MakeInplaceBytes<StringObj>(other, N)) {}
246+
: ObjectRef(details::MakeInplaceBytes<details::StringObj>(other, N)) {}
249247

250248
/*!
251249
* \brief constructor
@@ -258,37 +256,37 @@ class String : public ObjectRef {
258256
* \param other a char array.
259257
*/
260258
String(const char* other) // NOLINT(*)
261-
: ObjectRef(details::MakeInplaceBytes<StringObj>(other, std::strlen(other))) {}
259+
: ObjectRef(details::MakeInplaceBytes<details::StringObj>(other, std::strlen(other))) {}
262260

263261
/*!
264262
* \brief constructor from raw string
265263
*
266264
* \param other a char array.
267265
*/
268266
String(const char* other, size_t size) // NOLINT(*)
269-
: ObjectRef(details::MakeInplaceBytes<StringObj>(other, size)) {}
267+
: ObjectRef(details::MakeInplaceBytes<details::StringObj>(other, size)) {}
270268

271269
/*!
272270
* \brief Construct a new string object
273271
* \param other The std::string object to be copied
274272
*/
275273
String(const std::string& other) // NOLINT(*)
276-
: ObjectRef(details::MakeInplaceBytes<StringObj>(other.data(), other.size())) {}
274+
: ObjectRef(details::MakeInplaceBytes<details::StringObj>(other.data(), other.size())) {}
277275

278276
/*!
279277
* \brief Construct a new string object
280278
* \param other The std::string object to be moved
281279
*/
282280
String(std::string&& other) // NOLINT(*)
283-
: ObjectRef(make_object<details::BytesObjStdImpl<StringObj>>(std::move(other))) {}
281+
: ObjectRef(make_object<details::BytesObjStdImpl<details::StringObj>>(std::move(other))) {}
284282

285283
/*!
286284
* \brief constructor from TVMFFIByteArray
287285
*
288286
* \param other a TVMFFIByteArray.
289287
*/
290288
explicit String(TVMFFIByteArray other)
291-
: ObjectRef(details::MakeInplaceBytes<StringObj>(other.data, other.size)) {}
289+
: ObjectRef(details::MakeInplaceBytes<details::StringObj>(other.data, other.size)) {}
292290

293291
/*!
294292
* \brief Swap this String with another string
@@ -423,7 +421,7 @@ class String : public ObjectRef {
423421
*/
424422
operator std::string() const { return std::string{get()->data, size()}; }
425423

426-
TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj);
424+
TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, details::StringObj);
427425

428426
private:
429427
/*!

ffi/src/ffi/extra/structural_equal.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,10 @@ class StructEqualHandler {
6262
case TypeIndex::kTVMFFIStr:
6363
case TypeIndex::kTVMFFIBytes: {
6464
// compare bytes
65-
const BytesObjBase* lhs_str =
66-
AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(lhs);
67-
const BytesObjBase* rhs_str =
68-
AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(rhs);
65+
const details::BytesObjBase* lhs_str =
66+
AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs);
67+
const details::BytesObjBase* rhs_str =
68+
AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs);
6969
return Bytes::memncmp(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size) == 0;
7070
}
7171
case TypeIndex::kTVMFFIArray: {

ffi/src/ffi/extra/structural_hash.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ class StructuralHashHandler {
6464
case TypeIndex::kTVMFFIStr:
6565
case TypeIndex::kTVMFFIBytes: {
6666
// return same hash as AnyHash
67-
const BytesObjBase* src_str =
68-
AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(src);
67+
const details::BytesObjBase* src_str =
68+
AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(src);
6969
return details::StableHashCombine(src_data->type_index,
7070
details::StableHashBytes(src_str->data, src_str->size));
7171
}
@@ -196,8 +196,8 @@ class StructuralHashHandler {
196196
} else {
197197
if (src_data->type_index == TypeIndex::kTVMFFIStr ||
198198
src_data->type_index == TypeIndex::kTVMFFIBytes) {
199-
const BytesObjBase* src_str =
200-
AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(src);
199+
const details::BytesObjBase* src_str =
200+
AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(src);
201201
// return same hash as AnyHash
202202
return details::StableHashCombine(src_data->type_index,
203203
details::StableHashBytes(src_str->data, src_str->size));

include/tvm/ir/transform.h

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -242,26 +242,40 @@ class PassContext : public ObjectRef {
242242
template <typename ValueType>
243243
static int32_t RegisterConfigOption(const char* key) {
244244
// NOTE: we could further update the function later.
245-
int32_t tindex = ffi::TypeToRuntimeTypeIndex<ValueType>::v();
246-
auto* reflection = ReflectionVTable::Global();
247-
auto type_key = ffi::TypeIndexToTypeKey(tindex);
248-
249-
auto legalization = [=](ffi::Any value) -> ffi::Any {
250-
if (auto opt_map = value.try_cast<Map<String, ffi::Any>>()) {
251-
return reflection->CreateObject(type_key, opt_map.value());
252-
} else {
245+
if constexpr (std::is_base_of_v<ObjectRef, ValueType>) {
246+
int32_t tindex = ffi::TypeToRuntimeTypeIndex<ValueType>::v();
247+
auto* reflection = ReflectionVTable::Global();
248+
auto type_key = ffi::TypeIndexToTypeKey(tindex);
249+
auto legalization = [=](ffi::Any value) -> ffi::Any {
250+
if (auto opt_map = value.try_cast<Map<String, ffi::Any>>()) {
251+
return reflection->CreateObject(type_key, opt_map.value());
252+
} else {
253+
auto opt_val = value.try_cast<ValueType>();
254+
if (!opt_val.has_value()) {
255+
TVM_FFI_THROW(AttributeError)
256+
<< "Expect config " << key << " to have type " << type_key << ", but instead get "
257+
<< ffi::details::AnyUnsafe::GetMismatchTypeInfo<ValueType>(value);
258+
}
259+
return *opt_val;
260+
}
261+
};
262+
RegisterConfigOption(key, type_key, legalization);
263+
} else {
264+
// non-object type, do not support implicit conversion from map
265+
std::string type_str = ffi::TypeTraits<ValueType>::TypeStr();
266+
auto legalization = [=](ffi::Any value) -> ffi::Any {
253267
auto opt_val = value.try_cast<ValueType>();
254268
if (!opt_val.has_value()) {
255269
TVM_FFI_THROW(AttributeError)
256-
<< "Expect config " << key << " to have type " << type_key << ", but instead get "
270+
<< "Expect config " << key << " to have type " << type_str << ", but instead get "
257271
<< ffi::details::AnyUnsafe::GetMismatchTypeInfo<ValueType>(value);
272+
} else {
273+
return *opt_val;
258274
}
259-
return value;
260-
}
261-
};
262-
263-
RegisterConfigOption(key, tindex, legalization);
264-
return tindex;
275+
};
276+
RegisterConfigOption(key, type_str, legalization);
277+
}
278+
return 0;
265279
}
266280

267281
// accessor.
@@ -274,7 +288,7 @@ class PassContext : public ObjectRef {
274288
// The exit of a pass context scope.
275289
TVM_DLL void ExitWithScope();
276290
// Register configuration key value type.
277-
TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index,
291+
TVM_DLL static void RegisterConfigOption(const char* key, String value_type_str,
278292
std::function<ffi::Any(ffi::Any)> legalization);
279293

280294
// Classes to get the Python `with` like syntax.

include/tvm/runtime/profiling.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ class MetricCollectorNode : public Object {
315315
/*! \brief Stop collecting metrics.
316316
* \param obj The object created by the corresponding `Start` call.
317317
* \returns A set of metric names and the associated values. Values must be
318-
* one of DurationNode, PercentNode, CountNode, or StringObj.
318+
* one of DurationNode, PercentNode, CountNode, or String.
319319
*/
320320
virtual Map<String, ffi::Any> Stop(ffi::ObjectRef obj) = 0;
321321

include/tvm/target/target_kind.h

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -287,13 +287,27 @@ struct ValueTypeInfoMaker<ValueType, std::false_type, std::false_type> {
287287
using ValueTypeInfo = TargetKindNode::ValueTypeInfo;
288288

289289
ValueTypeInfo operator()() const {
290-
int32_t tindex = ffi::TypeToRuntimeTypeIndex<ValueType>::v();
291290
ValueTypeInfo info;
292-
info.type_index = tindex;
293-
info.type_key = runtime::Object::TypeIndex2Key(tindex);
294291
info.key = nullptr;
295292
info.val = nullptr;
296-
return info;
293+
if constexpr (std::is_base_of_v<ObjectRef, ValueType>) {
294+
int32_t tindex = ffi::TypeToRuntimeTypeIndex<ValueType>::v();
295+
info.type_index = tindex;
296+
info.type_key = runtime::Object::TypeIndex2Key(tindex);
297+
return info;
298+
} else if constexpr (std::is_same_v<ValueType, String>) {
299+
// special handle string since it can be backed by multiple types.
300+
info.type_index = ffi::TypeIndex::kTVMFFIStr;
301+
info.type_key = ffi::TypeTraits<ValueType>::TypeStr();
302+
return info;
303+
} else {
304+
// TODO(tqchen) consider upgrade to leverage any system to support union type
305+
constexpr int32_t tindex = ffi::TypeToFieldStaticTypeIndex<ValueType>::value;
306+
static_assert(tindex != ffi::TypeIndex::kTVMFFIAny, "Do not support union type for now");
307+
info.type_index = tindex;
308+
info.type_key = runtime::Object::TypeIndex2Key(tindex);
309+
return info;
310+
}
297311
}
298312
};
299313

python/tvm/exec/disco_worker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ def _str_func(x: str):
4848

4949

5050
@register_func("tests.disco.str_obj", override=True)
51-
def _str_obj_func(x: String):
52-
assert isinstance(x, String)
51+
def _str_obj_func(x: str):
52+
assert isinstance(x, str)
5353
return String(x + "_suffix")
5454

5555

python/tvm/ffi/cython/function.pxi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args) except
8787
out[i].type_index = kTVMFFINDArray
8888
out[i].v_ptr = (<NDArray>arg).chandle
8989
temp_args.append(arg)
90-
elif isinstance(arg, PyNativeObject):
90+
elif isinstance(arg, PyNativeObject) and arg.__tvm_ffi_object__ is not None:
9191
arg = arg.__tvm_ffi_object__
9292
out[i].type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
9393
out[i].v_ptr = (<Object>arg).chandle

python/tvm/ffi/cython/string.pxi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class String(str, PyNativeObject):
4040
"""
4141
def __new__(cls, value):
4242
val = str.__new__(cls, value)
43-
val.__init_tvm_ffi_object_by_constructor__(_STR_CONSTRUCTOR, value)
43+
val.__tvm_ffi_object__ = None
4444
return val
4545

4646
# pylint: disable=no-self-argument
@@ -65,7 +65,7 @@ class Bytes(bytes, PyNativeObject):
6565
"""
6666
def __new__(cls, value):
6767
val = bytes.__new__(cls, value)
68-
val.__init_tvm_ffi_object_by_constructor__(_BYTES_CONSTRUCTOR, value)
68+
val.__tvm_ffi_object__ = None
6969
return val
7070

7171
# pylint: disable=no-self-argument

0 commit comments

Comments
 (0)