Skip to content

[SYCL][Docs] Add std::hash and std::numeric_limits specialization for bfloat16 #19838

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: sycl
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions sycl/doc/extensions/supported/sycl_ext_oneapi_bfloat16.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,23 @@ int main(int argc, char *argv[]) {
}
----

=== Standard C++ library specializations

The `bfloat16` class has specializations of the `std::numeric_limits` and
`std::hash` standard C++ library classes. That is, an implementation must obey
the following statements:

1. A specialization of `std::hash` for `sycl::ext::oneapi::bfloat16` must exist
in the SYCL implementation that returns a unique value such that if two
instances of `sycl::ext::oneapi::bfloat16` are equal, in accordance with the
`==` operator, then their resulting hash values are also equal and
subsequently if two hash values are not equal, then their corresponding
instances are also not equal.
2. A specialztion of `std::numeric_limits` for `sycl::ext::oneapi::bfloat16`
defining the arithmetic properties of the `sycl::ext::oneapi::bfloat16` in
accordance with the requirements specified by the standard C++ library
specification.

== Revision History

[cols="5,15,15,70"]
Expand Down
87 changes: 87 additions & 0 deletions sycl/include/sycl/ext/oneapi/bfloat16.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,20 @@ class bfloat16 {
private:
Bfloat16StorageT value;

// Private tag used to avoid constructor ambiguity.
struct private_tag {
explicit private_tag() = default;
};

constexpr bfloat16(Bfloat16StorageT Value, private_tag) : value{Value} {}

// Explicit conversion functions
static float to_float(const Bfloat16StorageT &a);
static Bfloat16StorageT from_float(const float &a);

// Friend traits.
friend std::numeric_limits<bfloat16>;

// Friend classes for vector operations
friend class sycl::vec<bfloat16, 1>;
friend class sycl::vec<bfloat16, 2>;
Expand Down Expand Up @@ -615,3 +625,80 @@ inline bfloat16 getBfloat16WithRoundingMode(const Ty &a) {
} // namespace ext::oneapi
} // namespace _V1
} // namespace sycl

// Specialization of some functions in namespace `std`.
namespace std {
Comment on lines +629 to +630
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Subjective, but I prefer

// in global ns
template <> struct std::{type}<types...> { ... };

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have a strong preference, but the current style is what we use for hash in other headers, so I would like to keep it as is. Then we can change it at a global level if we want.


// Specialization of `std::hash<sycl::ext::oneapi::bfloat16>`.
template <> struct hash<sycl::ext::oneapi::bfloat16> {
size_t operator()(sycl::ext::oneapi::bfloat16 const &Key) const noexcept {
return hash<uint16_t>{}(sycl::bit_cast<uint16_t>(Key));
}
};

// Specialization of `std::numeric_limits<sycl::ext::oneapi::bfloat16>`.
template <> struct numeric_limits<sycl::ext::oneapi::bfloat16> {
// All following values are calculated based on description of each
// function/value on https://en.cppreference.com/w/cpp/types/numeric_limits.
static constexpr bool is_specialized = true;
static constexpr bool is_signed = true;
static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr bool has_infinity = true;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = true;
static constexpr float_denorm_style has_denorm = denorm_present;
static constexpr bool has_denorm_loss = false;
static constexpr bool tinyness_before = false;
static constexpr bool traps = false;
static constexpr int max_exponent10 = 35;
static constexpr int max_exponent = 127;
static constexpr int min_exponent10 = -37;
static constexpr int min_exponent = -126;
static constexpr int radix = 2;
static constexpr int max_digits10 = 4;
static constexpr int digits = 8;
static constexpr bool is_bounded = true;
static constexpr int digits10 = 2;
static constexpr bool is_modulo = false;
static constexpr bool is_iec559 = true;
static constexpr float_round_style round_style = round_to_nearest;

static constexpr const sycl::ext::oneapi::bfloat16(min)() noexcept {
return {uint16_t(0x80), sycl::ext::oneapi::bfloat16::private_tag{}};
}

static constexpr const sycl::ext::oneapi::bfloat16(max)() noexcept {
return {uint16_t(0x7f7f), sycl::ext::oneapi::bfloat16::private_tag{}};
}

static constexpr const sycl::ext::oneapi::bfloat16 lowest() noexcept {
return {uint16_t(0xff7f), sycl::ext::oneapi::bfloat16::private_tag{}};
}

static constexpr const sycl::ext::oneapi::bfloat16 epsilon() noexcept {
return {uint16_t(0x3c00), sycl::ext::oneapi::bfloat16::private_tag{}};
}

static constexpr const sycl::ext::oneapi::bfloat16 round_error() noexcept {
return {uint16_t(0x3f00), sycl::ext::oneapi::bfloat16::private_tag{}};
}

static constexpr const sycl::ext::oneapi::bfloat16 infinity() noexcept {
return {uint16_t(0x7f80), sycl::ext::oneapi::bfloat16::private_tag{}};
}

static constexpr const sycl::ext::oneapi::bfloat16 quiet_NaN() noexcept {
return {uint16_t(0x7fc0), sycl::ext::oneapi::bfloat16::private_tag{}};
}

static constexpr const sycl::ext::oneapi::bfloat16 signaling_NaN() noexcept {
return {uint16_t(0xff81), sycl::ext::oneapi::bfloat16::private_tag{}};
}

static constexpr const sycl::ext::oneapi::bfloat16 denorm_min() noexcept {
return {uint16_t(0x1), sycl::ext::oneapi::bfloat16::private_tag{}};
}
};

} // namespace std
68 changes: 68 additions & 0 deletions sycl/unittests/Extensions/BFloat16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,71 @@ TEST(BFloat16, BF16ToFloat) {
ASSERT_EQ(Res,
bitsToFloatConv(std::string("11111111011111110000000000000000")));
}

TEST(BFloat16, BF16Limits) {
namespace sycl_ext = sycl::ext::oneapi;
using Limit = std::numeric_limits<sycl_ext::bfloat16>;
constexpr float Log10_2 = 0.30103f;
auto constexpr_ceil = [](float Val) constexpr -> int {
return Val + (float(int(Val)) == Val ? 0.f : 1.f);
};

static_assert(Limit::is_specialized);
static_assert(Limit::is_signed);
static_assert(!Limit::is_integer);
static_assert(!Limit::is_exact);
static_assert(Limit::has_infinity);
static_assert(Limit::has_quiet_NaN);
static_assert(Limit::has_signaling_NaN);
static_assert(Limit::has_denorm == std::float_denorm_style::denorm_present);
static_assert(!Limit::has_denorm_loss);
static_assert(!Limit::tinyness_before);
static_assert(!Limit::traps);
static_assert(Limit::max_exponent10 == 35);
static_assert(Limit::max_exponent == 127);
static_assert(Limit::min_exponent10 == -37);
static_assert(Limit::min_exponent == -126);
static_assert(Limit::radix == 2);
static_assert(Limit::digits == 8);
static_assert(Limit::max_digits10 ==
constexpr_ceil(float(Limit::digits) * Log10_2 + 1.0f));
static_assert(Limit::is_bounded);
static_assert(Limit::digits10 == int(Limit::digits * Log10_2));
static_assert(!Limit::is_modulo);
static_assert(Limit::is_iec559);
static_assert(Limit::round_style == std::float_round_style::round_to_nearest);

EXPECT_TRUE(sycl_ext::experimental::isnan(Limit::quiet_NaN()));
EXPECT_TRUE(sycl_ext::experimental::isnan(Limit::signaling_NaN()));
// isinf does not exist for bfloat16 currently.
EXPECT_EQ(Limit::infinity(),
sycl::bit_cast<sycl_ext::bfloat16>(uint16_t(0xff << 7)));
EXPECT_EQ(Limit::round_error(), sycl_ext::bfloat16(0.5f));
EXPECT_GT(sycl_ext::bfloat16{1.0f} + Limit::epsilon(),
sycl_ext::bfloat16{1.0f});

for (uint16_t Sign : {0, 1})
for (uint16_t Exponent = 0; Exponent < 0xff; ++Exponent)
for (uint16_t Significand = 0; Significand < 0x7f; ++Significand) {
const auto Value = sycl::bit_cast<sycl_ext::bfloat16>(
uint16_t((Sign << 15) | (Exponent << 7) | Significand));

EXPECT_LE(Limit::lowest(), Value);
EXPECT_GE(Limit::max(), Value);

// min() is the lowest normal number, so if Value is negative, 0 or a
// subnormal - the latter two being represented by a 0-exponent - min()
// must be strictly greater.
if (Sign || Exponent == 0x0)
EXPECT_GT(Limit::min(), Value);
else
EXPECT_LE(Limit::min(), Value);

// denorm_min() is the lowest subnormal number, so if Value is negative
// or 0 denorm_min() must be strictly greater.
if (Sign || (Exponent == 0x0 && Significand == 0x0))
EXPECT_GT(Limit::denorm_min(), Value);
else
EXPECT_LE(Limit::denorm_min(), Value);
}
}