diff --git a/sycl/doc/extensions/supported/sycl_ext_oneapi_bfloat16.asciidoc b/sycl/doc/extensions/supported/sycl_ext_oneapi_bfloat16.asciidoc index bae11769da941..779823c854b6c 100644 --- a/sycl/doc/extensions/supported/sycl_ext_oneapi_bfloat16.asciidoc +++ b/sycl/doc/extensions/supported/sycl_ext_oneapi_bfloat16.asciidoc @@ -303,6 +303,23 @@ int main(int argc, char *argv[]) { } ---- +=== Standard C++ library specializations + +The `bfloat16` class has specializations of the `std::numeric_limits` and +`std::hash` standard C++ library classes. That is, an implementation must obey +the following statements: + + 1. A specialization of `std::hash` for `sycl::ext::oneapi::bfloat16` must exist + in the SYCL implementation that returns a unique value such that if two + instances of `sycl::ext::oneapi::bfloat16` are equal, in accordance with the + `==` operator, then their resulting hash values are also equal and + subsequently if two hash values are not equal, then their corresponding + instances are also not equal. + 2. A specialztion of `std::numeric_limits` for `sycl::ext::oneapi::bfloat16` + defining the arithmetic properties of the `sycl::ext::oneapi::bfloat16` in + accordance with the requirements specified by the standard C++ library + specification. + == Revision History [cols="5,15,15,70"] diff --git a/sycl/include/sycl/ext/oneapi/bfloat16.hpp b/sycl/include/sycl/ext/oneapi/bfloat16.hpp index 69d8cf3a7f366..5ba60801c8c77 100644 --- a/sycl/include/sycl/ext/oneapi/bfloat16.hpp +++ b/sycl/include/sycl/ext/oneapi/bfloat16.hpp @@ -141,10 +141,20 @@ class bfloat16 { private: Bfloat16StorageT value; + // Private tag used to avoid constructor ambiguity. + struct private_tag { + explicit private_tag() = default; + }; + + constexpr bfloat16(Bfloat16StorageT Value, private_tag) : value{Value} {} + // Explicit conversion functions static float to_float(const Bfloat16StorageT &a); static Bfloat16StorageT from_float(const float &a); + // Friend traits. + friend std::numeric_limits; + // Friend classes for vector operations friend class sycl::vec; friend class sycl::vec; @@ -615,3 +625,80 @@ inline bfloat16 getBfloat16WithRoundingMode(const Ty &a) { } // namespace ext::oneapi } // namespace _V1 } // namespace sycl + +// Specialization of some functions in namespace `std`. +namespace std { + +// Specialization of `std::hash`. +template <> struct hash { + size_t operator()(sycl::ext::oneapi::bfloat16 const &Key) const noexcept { + return hash{}(sycl::bit_cast(Key)); + } +}; + +// Specialization of `std::numeric_limits`. +template <> struct numeric_limits { + // All following values are calculated based on description of each + // function/value on https://en.cppreference.com/w/cpp/types/numeric_limits. + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = true; + static constexpr float_denorm_style has_denorm = denorm_present; + static constexpr bool has_denorm_loss = false; + static constexpr bool tinyness_before = false; + static constexpr bool traps = false; + static constexpr int max_exponent10 = 35; + static constexpr int max_exponent = 127; + static constexpr int min_exponent10 = -37; + static constexpr int min_exponent = -126; + static constexpr int radix = 2; + static constexpr int max_digits10 = 4; + static constexpr int digits = 8; + static constexpr bool is_bounded = true; + static constexpr int digits10 = 2; + static constexpr bool is_modulo = false; + static constexpr bool is_iec559 = true; + static constexpr float_round_style round_style = round_to_nearest; + + static constexpr const sycl::ext::oneapi::bfloat16(min)() noexcept { + return {uint16_t(0x80), sycl::ext::oneapi::bfloat16::private_tag{}}; + } + + static constexpr const sycl::ext::oneapi::bfloat16(max)() noexcept { + return {uint16_t(0x7f7f), sycl::ext::oneapi::bfloat16::private_tag{}}; + } + + static constexpr const sycl::ext::oneapi::bfloat16 lowest() noexcept { + return {uint16_t(0xff7f), sycl::ext::oneapi::bfloat16::private_tag{}}; + } + + static constexpr const sycl::ext::oneapi::bfloat16 epsilon() noexcept { + return {uint16_t(0x3c00), sycl::ext::oneapi::bfloat16::private_tag{}}; + } + + static constexpr const sycl::ext::oneapi::bfloat16 round_error() noexcept { + return {uint16_t(0x3f00), sycl::ext::oneapi::bfloat16::private_tag{}}; + } + + static constexpr const sycl::ext::oneapi::bfloat16 infinity() noexcept { + return {uint16_t(0x7f80), sycl::ext::oneapi::bfloat16::private_tag{}}; + } + + static constexpr const sycl::ext::oneapi::bfloat16 quiet_NaN() noexcept { + return {uint16_t(0x7fc0), sycl::ext::oneapi::bfloat16::private_tag{}}; + } + + static constexpr const sycl::ext::oneapi::bfloat16 signaling_NaN() noexcept { + return {uint16_t(0xff81), sycl::ext::oneapi::bfloat16::private_tag{}}; + } + + static constexpr const sycl::ext::oneapi::bfloat16 denorm_min() noexcept { + return {uint16_t(0x1), sycl::ext::oneapi::bfloat16::private_tag{}}; + } +}; + +} // namespace std diff --git a/sycl/unittests/Extensions/BFloat16.cpp b/sycl/unittests/Extensions/BFloat16.cpp index 6a2bd166ecc3a..8f5cd88cde39f 100644 --- a/sycl/unittests/Extensions/BFloat16.cpp +++ b/sycl/unittests/Extensions/BFloat16.cpp @@ -79,3 +79,71 @@ TEST(BFloat16, BF16ToFloat) { ASSERT_EQ(Res, bitsToFloatConv(std::string("11111111011111110000000000000000"))); } + +TEST(BFloat16, BF16Limits) { + namespace sycl_ext = sycl::ext::oneapi; + using Limit = std::numeric_limits; + constexpr float Log10_2 = 0.30103f; + auto constexpr_ceil = [](float Val) constexpr -> int { + return Val + (float(int(Val)) == Val ? 0.f : 1.f); + }; + + static_assert(Limit::is_specialized); + static_assert(Limit::is_signed); + static_assert(!Limit::is_integer); + static_assert(!Limit::is_exact); + static_assert(Limit::has_infinity); + static_assert(Limit::has_quiet_NaN); + static_assert(Limit::has_signaling_NaN); + static_assert(Limit::has_denorm == std::float_denorm_style::denorm_present); + static_assert(!Limit::has_denorm_loss); + static_assert(!Limit::tinyness_before); + static_assert(!Limit::traps); + static_assert(Limit::max_exponent10 == 35); + static_assert(Limit::max_exponent == 127); + static_assert(Limit::min_exponent10 == -37); + static_assert(Limit::min_exponent == -126); + static_assert(Limit::radix == 2); + static_assert(Limit::digits == 8); + static_assert(Limit::max_digits10 == + constexpr_ceil(float(Limit::digits) * Log10_2 + 1.0f)); + static_assert(Limit::is_bounded); + static_assert(Limit::digits10 == int(Limit::digits * Log10_2)); + static_assert(!Limit::is_modulo); + static_assert(Limit::is_iec559); + static_assert(Limit::round_style == std::float_round_style::round_to_nearest); + + EXPECT_TRUE(sycl_ext::experimental::isnan(Limit::quiet_NaN())); + EXPECT_TRUE(sycl_ext::experimental::isnan(Limit::signaling_NaN())); + // isinf does not exist for bfloat16 currently. + EXPECT_EQ(Limit::infinity(), + sycl::bit_cast(uint16_t(0xff << 7))); + EXPECT_EQ(Limit::round_error(), sycl_ext::bfloat16(0.5f)); + EXPECT_GT(sycl_ext::bfloat16{1.0f} + Limit::epsilon(), + sycl_ext::bfloat16{1.0f}); + + for (uint16_t Sign : {0, 1}) + for (uint16_t Exponent = 0; Exponent < 0xff; ++Exponent) + for (uint16_t Significand = 0; Significand < 0x7f; ++Significand) { + const auto Value = sycl::bit_cast( + uint16_t((Sign << 15) | (Exponent << 7) | Significand)); + + EXPECT_LE(Limit::lowest(), Value); + EXPECT_GE(Limit::max(), Value); + + // min() is the lowest normal number, so if Value is negative, 0 or a + // subnormal - the latter two being represented by a 0-exponent - min() + // must be strictly greater. + if (Sign || Exponent == 0x0) + EXPECT_GT(Limit::min(), Value); + else + EXPECT_LE(Limit::min(), Value); + + // denorm_min() is the lowest subnormal number, so if Value is negative + // or 0 denorm_min() must be strictly greater. + if (Sign || (Exponent == 0x0 && Significand == 0x0)) + EXPECT_GT(Limit::denorm_min(), Value); + else + EXPECT_LE(Limit::denorm_min(), Value); + } +}