Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ffi/include/tvm/ffi/any.h
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ struct AnyEqual {
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(lhs);
const BytesObjBase* rhs_str =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(rhs);
return Bytes::memncmp(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size) == 0;
return Bytes::memequal(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size);
}
return false;
}
Expand Down
36 changes: 25 additions & 11 deletions ffi/include/tvm/ffi/base_details.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,17 +181,30 @@ TVM_FFI_INLINE uint64_t StableHashBytes(const char* data, size_t size) {
const char* it = data;
const char* end = it + size;
uint64_t result = 0;
for (; it + 8 <= end; it += 8) {
if constexpr (TVM_FFI_IO_NO_ENDIAN_SWAP) {
u.a[0] = it[0];
u.a[1] = it[1];
u.a[2] = it[2];
u.a[3] = it[3];
u.a[4] = it[4];
u.a[5] = it[5];
u.a[6] = it[6];
u.a[7] = it[7];
if constexpr (TVM_FFI_IO_NO_ENDIAN_SWAP) {
// if alignment requirement is met, directly use load
if (reinterpret_cast<uintptr_t>(it) % 8 == 0) {
for (; it + 8 <= end; it += 8) {
u.b = *reinterpret_cast<const uint64_t*>(it);
result = (result * kMultiplier + u.b) % kMod;
}
} else {
// unaligned version
for (; it + 8 <= end; it += 8) {
u.a[0] = it[0];
u.a[1] = it[1];
u.a[2] = it[2];
u.a[3] = it[3];
u.a[4] = it[4];
u.a[5] = it[5];
u.a[6] = it[6];
u.a[7] = it[7];
result = (result * kMultiplier + u.b) % kMod;
}
}
} else {
// need endian swap
for (; it + 8 <= end; it += 8) {
u.a[0] = it[7];
u.a[1] = it[6];
u.a[2] = it[5];
Expand All @@ -200,9 +213,10 @@ TVM_FFI_INLINE uint64_t StableHashBytes(const char* data, size_t size) {
u.a[5] = it[2];
u.a[6] = it[1];
u.a[7] = it[0];
result = (result * kMultiplier + u.b) % kMod;
}
result = (result * kMultiplier + u.b) % kMod;
}

if (it < end) {
u.b = 0;
uint8_t* a = u.a;
Expand Down
70 changes: 49 additions & 21 deletions ffi/include/tvm/ffi/string.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,34 @@ class Bytes : public ObjectRef {
* \return int zero if both char sequences compare equal. negative if this
* appear before other, positive otherwise.
*/
static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count);
static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) {
if (lhs == rhs && lhs_count == rhs_count) return 0;

for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) {
if (lhs[i] < rhs[i]) return -1;
if (lhs[i] > rhs[i]) return 1;
}
if (lhs_count < rhs_count) {
return -1;
} else if (lhs_count > rhs_count) {
return 1;
} else {
return 0;
}
}
/*!
* \brief Compare two char sequence for equality
*
* \param lhs Pointers to the char array to compare
* \param rhs Pointers to the char array to compare
* \param lhs_count Length of the char array to compare
* \param rhs_count Length of the char array to compare
*
* \return true if the two char sequences are equal, false otherwise.
*/
static bool memequal(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) {
return lhs_count == rhs_count && (lhs == rhs || std::memcmp(lhs, rhs, lhs_count) == 0);
}

private:
friend class String;
Expand Down Expand Up @@ -311,7 +338,18 @@ class String : public ObjectRef {
* before other, positive otherwise.
*/
int compare(const char* other) const {
return Bytes::memncmp(data(), other, size(), std::strlen(other));
const char* this_data = data();
size_t this_size = size();
for (size_t i = 0; i < this_size; ++i) {
// other is shorter than this
if (other[i] == '\0') return 1;
if (this_data[i] < other[i]) return -1;
if (this_data[i] > other[i]) return 1;
}
// other equals this
if (other[this_size] == '\0') return 0;
// other longer than this
return -1;
}

/*!
Expand Down Expand Up @@ -616,11 +654,17 @@ inline bool operator>=(const String& lhs, const char* rhs) { return lhs.compare(
inline bool operator>=(const char* lhs, const String& rhs) { return rhs.compare(lhs) <= 0; }

// Overload == operator
inline bool operator==(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) == 0; }
inline bool operator==(const String& lhs, const std::string& rhs) {
return Bytes::memequal(lhs.data(), rhs.data(), lhs.size(), rhs.size());
}

inline bool operator==(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) == 0; }
inline bool operator==(const std::string& lhs, const String& rhs) {
return Bytes::memequal(lhs.data(), rhs.data(), lhs.size(), rhs.size());
}

inline bool operator==(const String& lhs, const String& rhs) { return lhs.compare(rhs) == 0; }
inline bool operator==(const String& lhs, const String& rhs) {
return Bytes::memequal(lhs.data(), rhs.data(), lhs.size(), rhs.size());
}

inline bool operator==(const String& lhs, const char* rhs) { return lhs.compare(rhs) == 0; }

Expand All @@ -641,22 +685,6 @@ inline std::ostream& operator<<(std::ostream& out, const String& input) {
out.write(input.data(), input.size());
return out;
}

inline int Bytes::memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) {
if (lhs == rhs && lhs_count == rhs_count) return 0;

for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) {
if (lhs[i] < rhs[i]) return -1;
if (lhs[i] > rhs[i]) return 1;
}
if (lhs_count < rhs_count) {
return -1;
} else if (lhs_count > rhs_count) {
return 1;
} else {
return 0;
}
}
} // namespace ffi

// Expose to the tvm namespace for usability
Expand Down
34 changes: 34 additions & 0 deletions ffi/tests/cpp/test_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,24 @@ TEST(String, Comparisons) {
EXPECT_EQ(m != s, mismatch != source);
}

TEST(String, Compare) {
// string compare const char*
String s{"hello"};
EXPECT_EQ(s.compare("hello"), 0);
EXPECT_EQ(s.compare(String("hello")), 0);

EXPECT_EQ(s.compare("hallo"), 1);
EXPECT_EQ(s.compare(String("hallo")), 1);
EXPECT_EQ(s.compare("hfllo"), -1);
EXPECT_EQ(s.compare(String("hfllo")), -1);
// s is longer
EXPECT_EQ(s.compare("hell"), 1);
EXPECT_EQ(s.compare(String("hell")), 1);
// s is shorter
EXPECT_EQ(s.compare("hello world"), -1);
EXPECT_EQ(s.compare(String("helloworld")), -1);
}

// Check '\0' handling
TEST(String, null_byte_handling) {
using namespace std;
Expand Down Expand Up @@ -369,4 +387,20 @@ TEST(String, CAPIAccessor) {
EXPECT_EQ(arr->size, 5);
EXPECT_EQ(std::string(arr->data, arr->size), "hello");
}

TEST(String, BytesHash) {
std::vector<int64_t> data1(10);
std::vector<int64_t> data2(11);
for (size_t i = 0; i < data1.size(); ++i) {
data1[i] = i;
}
char* data1_ptr = reinterpret_cast<char*>(data1.data());
char* data2_ptr = reinterpret_cast<char*>(data2.data()) + 1;
std::memcpy(data2_ptr, data1.data(), data1.size() * sizeof(int64_t));
// has of aligned and unaligned data should be the same
uint64_t hash1 = details::StableHashBytes(data1_ptr, data1.size() * sizeof(int64_t));
uint64_t hash2 = details::StableHashBytes(data2_ptr, data1.size() * sizeof(int64_t));
EXPECT_EQ(hash1, hash2);
}

} // namespace
Loading