Skip to content

Commit 8e94ee5

Browse files
authored
[FFI] Improve string equal/hash handling (#18176)
This PR improves string equality and hash handling by making several of the underlying operations more efficient.
1 parent efee448 commit 8e94ee5

File tree

4 files changed

+109
-33
lines changed

4 files changed

+109
-33
lines changed

ffi/include/tvm/ffi/any.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,7 @@ struct AnyEqual {
576576
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(lhs);
577577
const BytesObjBase* rhs_str =
578578
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(rhs);
579-
return Bytes::memncmp(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size) == 0;
579+
return Bytes::memequal(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size);
580580
}
581581
return false;
582582
}

ffi/include/tvm/ffi/base_details.h

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -181,17 +181,30 @@ TVM_FFI_INLINE uint64_t StableHashBytes(const char* data, size_t size) {
181181
const char* it = data;
182182
const char* end = it + size;
183183
uint64_t result = 0;
184-
for (; it + 8 <= end; it += 8) {
185-
if constexpr (TVM_FFI_IO_NO_ENDIAN_SWAP) {
186-
u.a[0] = it[0];
187-
u.a[1] = it[1];
188-
u.a[2] = it[2];
189-
u.a[3] = it[3];
190-
u.a[4] = it[4];
191-
u.a[5] = it[5];
192-
u.a[6] = it[6];
193-
u.a[7] = it[7];
184+
if constexpr (TVM_FFI_IO_NO_ENDIAN_SWAP) {
185+
// if the pointer is 8-byte aligned, load each 64-bit word directly
186+
if (reinterpret_cast<uintptr_t>(it) % 8 == 0) {
187+
for (; it + 8 <= end; it += 8) {
188+
u.b = *reinterpret_cast<const uint64_t*>(it);
189+
result = (result * kMultiplier + u.b) % kMod;
190+
}
194191
} else {
192+
// unaligned version
193+
for (; it + 8 <= end; it += 8) {
194+
u.a[0] = it[0];
195+
u.a[1] = it[1];
196+
u.a[2] = it[2];
197+
u.a[3] = it[3];
198+
u.a[4] = it[4];
199+
u.a[5] = it[5];
200+
u.a[6] = it[6];
201+
u.a[7] = it[7];
202+
result = (result * kMultiplier + u.b) % kMod;
203+
}
204+
}
205+
} else {
206+
// need endian swap
207+
for (; it + 8 <= end; it += 8) {
195208
u.a[0] = it[7];
196209
u.a[1] = it[6];
197210
u.a[2] = it[5];
@@ -200,9 +213,10 @@ TVM_FFI_INLINE uint64_t StableHashBytes(const char* data, size_t size) {
200213
u.a[5] = it[2];
201214
u.a[6] = it[1];
202215
u.a[7] = it[0];
216+
result = (result * kMultiplier + u.b) % kMod;
203217
}
204-
result = (result * kMultiplier + u.b) % kMod;
205218
}
219+
206220
if (it < end) {
207221
u.b = 0;
208222
uint8_t* a = u.a;

ffi/include/tvm/ffi/string.h

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,34 @@ class Bytes : public ObjectRef {
175175
* \return int zero if both char sequences compare equal. negative if this
176176
* appears before other, positive otherwise.
177177
*/
178-
static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count);
178+
static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) {
179+
if (lhs == rhs && lhs_count == rhs_count) return 0;
180+
181+
for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) {
182+
if (lhs[i] < rhs[i]) return -1;
183+
if (lhs[i] > rhs[i]) return 1;
184+
}
185+
if (lhs_count < rhs_count) {
186+
return -1;
187+
} else if (lhs_count > rhs_count) {
188+
return 1;
189+
} else {
190+
return 0;
191+
}
192+
}
193+
/*!
194+
* \brief Compare two char sequences for equality
195+
*
196+
* \param lhs Pointer to the left-hand char array to compare
197+
* \param rhs Pointer to the right-hand char array to compare
198+
* \param lhs_count Length of the char array to compare
199+
* \param rhs_count Length of the char array to compare
200+
*
201+
* \return true if the two char sequences are equal, false otherwise.
202+
*/
203+
static bool memequal(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) {
204+
return lhs_count == rhs_count && (lhs == rhs || std::memcmp(lhs, rhs, lhs_count) == 0);
205+
}
179206

180207
private:
181208
friend class String;
@@ -311,7 +338,18 @@ class String : public ObjectRef {
311338
* before other, positive otherwise.
312339
*/
313340
int compare(const char* other) const {
314-
return Bytes::memncmp(data(), other, size(), std::strlen(other));
341+
const char* this_data = data();
342+
size_t this_size = size();
343+
for (size_t i = 0; i < this_size; ++i) {
344+
// other is shorter than this
345+
if (other[i] == '\0') return 1;
346+
if (this_data[i] < other[i]) return -1;
347+
if (this_data[i] > other[i]) return 1;
348+
}
349+
// other equals this
350+
if (other[this_size] == '\0') return 0;
351+
// other longer than this
352+
return -1;
315353
}
316354

317355
/*!
@@ -616,11 +654,17 @@ inline bool operator>=(const String& lhs, const char* rhs) { return lhs.compare(
616654
inline bool operator>=(const char* lhs, const String& rhs) { return rhs.compare(lhs) <= 0; }
617655

618656
// Overload == operator
619-
inline bool operator==(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) == 0; }
657+
inline bool operator==(const String& lhs, const std::string& rhs) {
658+
return Bytes::memequal(lhs.data(), rhs.data(), lhs.size(), rhs.size());
659+
}
620660

621-
inline bool operator==(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) == 0; }
661+
inline bool operator==(const std::string& lhs, const String& rhs) {
662+
return Bytes::memequal(lhs.data(), rhs.data(), lhs.size(), rhs.size());
663+
}
622664

623-
inline bool operator==(const String& lhs, const String& rhs) { return lhs.compare(rhs) == 0; }
665+
inline bool operator==(const String& lhs, const String& rhs) {
666+
return Bytes::memequal(lhs.data(), rhs.data(), lhs.size(), rhs.size());
667+
}
624668

625669
inline bool operator==(const String& lhs, const char* rhs) { return lhs.compare(rhs) == 0; }
626670

@@ -641,22 +685,6 @@ inline std::ostream& operator<<(std::ostream& out, const String& input) {
641685
out.write(input.data(), input.size());
642686
return out;
643687
}
644-
645-
inline int Bytes::memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) {
646-
if (lhs == rhs && lhs_count == rhs_count) return 0;
647-
648-
for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) {
649-
if (lhs[i] < rhs[i]) return -1;
650-
if (lhs[i] > rhs[i]) return 1;
651-
}
652-
if (lhs_count < rhs_count) {
653-
return -1;
654-
} else if (lhs_count > rhs_count) {
655-
return 1;
656-
} else {
657-
return 0;
658-
}
659-
}
660688
} // namespace ffi
661689

662690
// Expose to the tvm namespace for usability

ffi/tests/cpp/test_string.cc

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,24 @@ TEST(String, Comparisons) {
9595
EXPECT_EQ(m != s, mismatch != source);
9696
}
9797

98+
TEST(String, Compare) {
99+
// string compare const char*
100+
String s{"hello"};
101+
EXPECT_EQ(s.compare("hello"), 0);
102+
EXPECT_EQ(s.compare(String("hello")), 0);
103+
104+
EXPECT_EQ(s.compare("hallo"), 1);
105+
EXPECT_EQ(s.compare(String("hallo")), 1);
106+
EXPECT_EQ(s.compare("hfllo"), -1);
107+
EXPECT_EQ(s.compare(String("hfllo")), -1);
108+
// s is longer
109+
EXPECT_EQ(s.compare("hell"), 1);
110+
EXPECT_EQ(s.compare(String("hell")), 1);
111+
// s is shorter
112+
EXPECT_EQ(s.compare("hello world"), -1);
113+
EXPECT_EQ(s.compare(String("helloworld")), -1);
114+
}
115+
98116
// Check '\0' handling
99117
TEST(String, null_byte_handling) {
100118
using namespace std;
@@ -369,4 +387,20 @@ TEST(String, CAPIAccessor) {
369387
EXPECT_EQ(arr->size, 5);
370388
EXPECT_EQ(std::string(arr->data, arr->size), "hello");
371389
}
390+
391+
TEST(String, BytesHash) {
392+
std::vector<int64_t> data1(10);
393+
std::vector<int64_t> data2(11);
394+
for (size_t i = 0; i < data1.size(); ++i) {
395+
data1[i] = i;
396+
}
397+
char* data1_ptr = reinterpret_cast<char*>(data1.data());
398+
char* data2_ptr = reinterpret_cast<char*>(data2.data()) + 1;
399+
std::memcpy(data2_ptr, data1.data(), data1.size() * sizeof(int64_t));
400+
// hash of aligned and unaligned data should be the same
401+
uint64_t hash1 = details::StableHashBytes(data1_ptr, data1.size() * sizeof(int64_t));
402+
uint64_t hash2 = details::StableHashBytes(data2_ptr, data1.size() * sizeof(int64_t));
403+
EXPECT_EQ(hash1, hash2);
404+
}
405+
372406
} // namespace

0 commit comments

Comments
 (0)