Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions libc/config/gpu/amdgpu/entrypoints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
list(APPEND TARGET_LIBM_ENTRYPOINTS
# math.h C23 _Float16 entrypoints
libc.src.math.canonicalizef16
libc.src.math.cbrtf16
libc.src.math.ceilf16
libc.src.math.copysignf16
libc.src.math.coshf16
Expand Down
1 change: 1 addition & 0 deletions libc/config/gpu/nvptx/entrypoints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
list(APPEND TARGET_LIBM_ENTRYPOINTS
# math.h C23 _Float16 entrypoints
libc.src.math.canonicalizef16
libc.src.math.cbrtf16
libc.src.math.ceilf16
libc.src.math.copysignf16
libc.src.math.coshf16
Expand Down
1 change: 1 addition & 0 deletions libc/config/linux/aarch64/entrypoints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
list(APPEND TARGET_LIBM_ENTRYPOINTS
# math.h C23 _Float16 entrypoints
libc.src.math.canonicalizef16
libc.src.math.cbrtf16
libc.src.math.ceilf16
libc.src.math.copysignf16
libc.src.math.cospif16
Expand Down
1 change: 1 addition & 0 deletions libc/config/linux/x86_64/entrypoints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
libc.src.math.cosf16
libc.src.math.coshf16
libc.src.math.cospif16
libc.src.math.cbrtf16
libc.src.math.exp10f16
libc.src.math.exp10m1f16
libc.src.math.exp2f16
Expand Down
1 change: 1 addition & 0 deletions libc/src/math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ add_math_entrypoint_object(iscanonicalf128)

add_math_entrypoint_object(cbrt)
add_math_entrypoint_object(cbrtf)
add_math_entrypoint_object(cbrtf16)

add_math_entrypoint_object(ceil)
add_math_entrypoint_object(ceilf)
Expand Down
21 changes: 21 additions & 0 deletions libc/src/math/cbrtf16.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//===-- Implementation header for cbrtf16 -----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIBC_SRC_MATH_CBRTF16_H
#define LLVM_LIBC_SRC_MATH_CBRTF16_H

#include "src/__support/macros/config.h" // LIBC_NAMESPACE_DECL
#include "src/__support/macros/properties/types.h" // float16

namespace LIBC_NAMESPACE_DECL {

float16 cbrtf16(float16 x);

} // namespace LIBC_NAMESPACE_DECL

#endif // LLVM_LIBC_SRC_MATH_CBRTF16_H
18 changes: 18 additions & 0 deletions libc/src/math/generic/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4816,6 +4816,24 @@ add_entrypoint_object(
libc.src.__support.integer_literals
)

add_entrypoint_object(
cbrtf16
SRCS
cbrtf16.cpp
HDRS
../cbrtf16.h
DEPENDS
libc.hdr.fenv_macros
libc.src.__support.FPUtil.double_double
libc.src.__support.FPUtil.dyadic_float
libc.src.__support.FPUtil.fenv_impl
libc.src.__support.FPUtil.fp_bits
libc.src.__support.FPUtil.multiply_add
libc.src.__support.FPUtil.polyeval
libc.src.__support.macros.optimization
libc.src.__support.integer_literals
)

add_entrypoint_object(
dmull
SRCS
Expand Down
2 changes: 1 addition & 1 deletion libc/src/math/generic/cbrtf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace {
// Look up table for 2^(i/3) for i = 0, 1, 2.
constexpr double CBRT2[3] = {1.0, 0x1.428a2f98d728bp0, 0x1.965fea53d6e3dp0};

// Degree-7 polynomials approximation of ((1 + x)^(1/3) - 1)/x for 0 <= x <= 1
// Degree-6 polynomials approximation of ((1 + x)^(1/3) - 1)/x for 0 <= x <= 1
// generated by Sollya with:
// > for i from 0 to 15 do {
// P = fpminimax(((1 + x)^(1/3) - 1)/x, 6, [|D...|], [i/16, (i + 1)/16]);
Expand Down
164 changes: 164 additions & 0 deletions libc/src/math/generic/cbrtf16.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
//===-- Implementation of sqrtf16 function --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "src/math/cbrtf16.h"
#include "hdr/fenv_macros.h"
#include "src/__support/FPUtil/FEnvImpl.h"
#include "src/__support/FPUtil/FPBits.h"
#include "src/__support/FPUtil/multiply_add.h"
#include "src/__support/common.h"
#include "src/__support/macros/config.h"
#include "src/__support/macros/optimization.h" // LIBC_UNLIKELY

namespace LIBC_NAMESPACE_DECL {

namespace {

// Look up table for 2^(i/3) for i = 0, 1, 2 in single precision
constexpr float CBRT2[3] = {0x1p0f, 0x1.428a3p0f, 0x1.965feap0f};

// Degree-4 polynomials approximation of ((1 + x)^(1/3) - 1)/x for 0 <= x <= 1
// generated by Sollya with:
// > display=hexadecimal;
// for i from 0 to 15 do {
// P = fpminimax(((1 + x)^(1/3) - 1)/x, 4, [|SG...|], [i/16, (i + 1)/16]);
// print("{", coeff(P, 0), ",", coeff(P, 1), ",", coeff(P, 2), ",",
// coeff(P, 3), coeff(P, 4),"},");
// };
// Then (1 + x)^(1/3) ~ 1 + x * P(x).
// For example: for 0 <= x <= 1/8:
// P(x) = 0x1.555556p-2 + x * (-0x1.c71d38p-4 + x * (0x1.f9b95ap-5 + x *
// (-0x1.4ebe18p-5 + x * 0x1.9ca9d2p-6)))

constexpr float COEFFS[16][5] = {
{0x1.555556p-2f, -0x1.c71ea4p-4f, 0x1.faa5f2p-5f, -0x1.64febep-5f,
0x1.733a46p-5f},
{0x1.55554ep-2f, -0x1.c715f6p-4f, 0x1.f88a9ep-5f, -0x1.4456e8p-5f,
0x1.5b5ef2p-6f},
{0x1.555508p-2f, -0x1.c6f404p-4f, 0x1.f56b7ap-5f, -0x1.33cff8p-5f,
0x1.18f146p-6f},
{0x1.5553fcp-2f, -0x1.c69bacp-4f, 0x1.efed98p-5f, -0x1.204706p-5f,
0x1.c90976p-7f},
{0x1.55517p-2f, -0x1.c5f996p-4f, 0x1.e85932p-5f, -0x1.0c0c0ep-5f,
0x1.77c766p-7f},
{0x1.554c96p-2f, -0x1.c501d2p-4f, 0x1.df0fc4p-5f, -0x1.f067f2p-6f,
0x1.380ab8p-7f},
{0x1.55448cp-2f, -0x1.c3ab1ep-4f, 0x1.d45876p-5f, -0x1.ca3988p-6f,
0x1.04f38ap-7f},
{0x1.5538aap-2f, -0x1.c1f886p-4f, 0x1.c8b11p-5f, -0x1.a6a16cp-6f,
0x1.b847c2p-8f},
{0x1.55278ap-2f, -0x1.bfd538p-4f, 0x1.bbde6p-5f, -0x1.846a8cp-6f,
0x1.73bfcp-8f},
{0x1.5511dp-2f, -0x1.bd6c88p-4f, 0x1.af0a3ap-5f, -0x1.660852p-6f,
0x1.3dbe34p-8f},
{0x1.54f82ap-2f, -0x1.bada56p-4f, 0x1.a2aa0ep-5f, -0x1.4b8c2ap-6f,
0x1.13379cp-8f},
{0x1.54d512p-2f, -0x1.b7a936p-4f, 0x1.94b91ep-5f, -0x1.30792cp-6f,
0x1.d7883cp-9f},
{0x1.54a8d8p-2f, -0x1.b3fde2p-4f, 0x1.861aeep-5f, -0x1.169484p-6f,
0x1.92b4cap-9f},
{0x1.548126p-2f, -0x1.b0f4a8p-4f, 0x1.7af574p-5f, -0x1.04644ep-6f,
0x1.662fb6p-9f},
{0x1.544b9p-2f, -0x1.ad2124p-4f, 0x1.6dd75p-5f, -0x1.e0cbecp-7f,
0x1.387692p-9f},
{0x1.5422c6p-2f, -0x1.aa61bp-4f, 0x1.64f4bap-5f, -0x1.c742b2p-7f,
0x1.1cf15ap-9f},
};

} // anonymous namespace

LLVM_LIBC_FUNCTION(float16, cbrtf16, (float16 x)) {
using FPBits = fputil::FPBits<float16>;
using FloatBits = fputil::FPBits<float>;

FPBits x_bits(x);

uint16_t x_u = x_bits.uintval();
uint16_t x_abs = x_u & 0x7fff;
uint32_t sign_bit = (x_u >> 15) << FloatBits::EXP_LEN;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cast (x_u >> 15) to uint32_t before left-shifting.


// cbrtf16(0) = 0, cbrtf16(NaN) = NaN
if (LIBC_UNLIKELY(x_abs == 0 || x_abs >= 0x7C00)) {
if (x_bits.is_signaling_nan()) {
fputil::raise_except(FE_INVALID);
return FPBits::quiet_nan().uintval();
}
return x;
}

float xf = static_cast<float>(x);
FloatBits xf_bits(xf);

unsigned x_e = static_cast<unsigned>(xf_bits.get_exponent());
Copy link
Contributor

@lntue lntue Mar 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

casting x_e to unsigned before dividing by 3 will give you completely wrong results for negative x_e.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh yeah sorry! i completely missed that it could be negative as wel

unsigned out_e = (x_e / 3 + 127) | sign_bit;

unsigned shift_e = x_e % 3;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think the error for $|x| &lt; 1$ is due to these


// Set x_m = 2^(x_e % 3) * (1 + mantissa)
uint32_t x_m = xf_bits.get_mantissa();

// Use the leading 4 bits for look up table
unsigned idx = static_cast<unsigned>(x_m >> (FloatBits::FRACTION_LEN - 4));

x_m |= static_cast<uint32_t>(FloatBits::EXP_BIAS) << FloatBits::FRACTION_LEN;

float x_reduced = FloatBits(x_m).get_val();
float dx = x_reduced - 1.0f;

float dx_sq = dx * dx;

// fputil::multiply_add(x, y, z) = x*y + z

// c0 = 1 + x * a0
float c0 = fputil::multiply_add(dx, COEFFS[idx][0], 1.0f);
// c1 = a1 + x * a2
float c1 = fputil::multiply_add(dx, COEFFS[idx][2], COEFFS[idx][1]);
// c2 = a3 + x * a4
float c2 = fputil::multiply_add(dx, COEFFS[idx][4], COEFFS[idx][3]);
// we save a multiply_add operation by decreasing the polynomial degree by 2
// i.e. using a degree-4 polynomial instead of degree 6.

float dx_4 = dx_sq * dx_sq;

// p0 = c0 + x^2 * c1
// p0 = (1 + x * a0) + x^2 * (a1 + x * a2)
// p0 = 1 + x * a0 + x^2 * a1 + x^3 * a2
float p0 = fputil::multiply_add(dx_sq, c1, c0);

// p1 = c2
// p1 = x * a4
float p1 = c2;

// r = p0 + x^4 * p1
// r = (1 + x * a0 + x^2 * a1 + x^3 * a2) + x^4 (x * a4)
// r = 1 + x * a0 + x^2 * a1 + x^3 * a2 + x^5 * a4
// r = 1 + x * (a0 + a1 * x + a2 * x^2 + a3 * x^3 + a4 * x^4)
// r = 1 + x * P(x)
float r = fputil::multiply_add(dx_4, p1, p0) * CBRT2[shift_e];

uint32_t r_m = FloatBits(r).get_mantissa();
// For float, mantissa is 23 bits (instead of 52 for double)
// Check if the output is exact. To be exact, the smallest 1-bit of the
// output has to be at least 2^-7 or higher. So we check the lowest 15 bits
// to see if they are within 2^(-23 + 3) errors from all zeros, then the
// result cube root is exact.
if (LIBC_UNLIKELY(((r_m + 4) & 0x7fff) <= 8)) {
if ((r_m & 0x7fff) <= 4)
r_m &= 0xffff'ffe0;
else
r_m = (r_m & 0xffff'ffe0) + 0x20; // Round up to next multiple of 0x20
fputil::clear_except_if_required(FE_INEXACT);
}

uint32_t r_bits =
r_m | (static_cast<uint32_t>(out_e) << FloatBits::FRACTION_LEN);

return static_cast<float16>(FloatBits(r_bits).get_val());
}

} // namespace LIBC_NAMESPACE_DECL
12 changes: 12 additions & 0 deletions libc/test/src/math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2655,6 +2655,18 @@ add_fp_unittest(
libc.src.__support.FPUtil.fp_bits
)

add_fp_unittest(
cbrtf16_test
NEED_MPFR
SUITE
libc-math-unittests
SRCS
cbrtf16_test.cpp
DEPENDS
libc.src.math.cbrtf16
libc.src.__support.FPUtil.fp_bits
)

add_fp_unittest(
dmull_test
NEED_MPFR
Expand Down
56 changes: 56 additions & 0 deletions libc/test/src/math/cbrtf16_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
//===-- Unittests for cbrtf16 ---------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "hdr/math_macros.h"
#include "src/__support/FPUtil/FPBits.h"
#include "src/math/cbrtf16.h"
#include "test/UnitTest/FPMatcher.h"
#include "test/UnitTest/Test.h"
#include "utils/MPFRWrapper/MPFRUtils.h"

using LlvmLibcCbrtf16Test = LIBC_NAMESPACE::testing::FPTest<float16>;

namespace mpfr = LIBC_NAMESPACE::testing::mpfr;

// Range: [0, Inf];
static constexpr uint16_t POS_START = 0x0000U;
static constexpr uint16_t POS_STOP = 0x7c00U;

// Range: [-Inf, 0]
static constexpr uint16_t NEG_START = 0x8000U;
static constexpr uint16_t NEG_STOP = 0xfc00U;

TEST_F(LlvmLibcCbrtf16Test, PositiveRange) {
for (uint16_t v = POS_START; v <= POS_STOP; ++v) {
float16 x = FPBits(v).get_val();
EXPECT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Cbrt, x,
LIBC_NAMESPACE::cbrtf16(x), 0.5);
}
}

TEST_F(LlvmLibcCbrtf16Test, NegativeRange) {
for (uint16_t v = NEG_START; v <= NEG_STOP; ++v) {
float16 x = FPBits(v).get_val();
EXPECT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Cbrt, x,
LIBC_NAMESPACE::cbrtf16(x), 0.5);
}
}

TEST_F(LlvmLibcCbrtf16Test, SpecialValues) {
constexpr uint16_t INPUTS[] = {
0x4a00, 0x4500, 0x4e00, 0x0c00, 0x4940,
};
for (uint16_t v : INPUTS) {
float16 x = FPBits(v).get_val();
mpfr::ForceRoundingMode r(mpfr::RoundingMode::Upward);
EXPECT_MPFR_MATCH(mpfr::Operation::Cbrt, x, LIBC_NAMESPACE::cbrtf16(x), 0.5,
mpfr::RoundingMode::Upward);
}

ASSERT_EQ(1, 1);
}
10 changes: 10 additions & 0 deletions libc/test/src/math/smoke/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5042,6 +5042,16 @@ add_fp_unittest(
libc.src.math.cbrt
)

add_fp_unittest(
cbrtf16_test
SUITE
libc-math-smoke-tests
SRCS
cbrtf16_test.cpp
DEPENDS
libc.src.math.cbrtf16
)

add_fp_unittest(
dmull_test
SUITE
Expand Down
33 changes: 33 additions & 0 deletions libc/test/src/math/smoke/cbrtf16_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
//===-- Unittests for cbrtf16 ---------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "src/math/cbrtf16.h"
#include "test/UnitTest/FPMatcher.h"
#include "test/UnitTest/Test.h"

using LlvmLibcCbrtfTest = LIBC_NAMESPACE::testing::FPTest<float16>;

using LIBC_NAMESPACE::testing::tlog;

TEST_F(LlvmLibcCbrtfTest, SpecialNumbers) {
EXPECT_FP_EQ_ALL_ROUNDING(aNaN, LIBC_NAMESPACE::cbrtf16(aNaN));
EXPECT_FP_EQ_ALL_ROUNDING(inf, LIBC_NAMESPACE::cbrtf16(inf));
EXPECT_FP_EQ_ALL_ROUNDING(neg_inf, LIBC_NAMESPACE::cbrtf16(neg_inf));
EXPECT_FP_EQ_ALL_ROUNDING(zero, LIBC_NAMESPACE::cbrtf16(zero));
EXPECT_FP_EQ_ALL_ROUNDING(neg_zero, LIBC_NAMESPACE::cbrtf16(neg_zero));
EXPECT_FP_EQ_ALL_ROUNDING(1.0f, LIBC_NAMESPACE::cbrtf16(1.0f));
EXPECT_FP_EQ_ALL_ROUNDING(-1.0f, LIBC_NAMESPACE::cbrtf16(-1.0f));
EXPECT_FP_EQ_ALL_ROUNDING(2.0f, LIBC_NAMESPACE::cbrtf16(8.0f));
EXPECT_FP_EQ_ALL_ROUNDING(-2.0f, LIBC_NAMESPACE::cbrtf16(-8.0f));
EXPECT_FP_EQ_ALL_ROUNDING(3.0f, LIBC_NAMESPACE::cbrtf16(27.0f));
EXPECT_FP_EQ_ALL_ROUNDING(-3.0f, LIBC_NAMESPACE::cbrtf16(-27.0f));
EXPECT_FP_EQ_ALL_ROUNDING(5.0f, LIBC_NAMESPACE::cbrtf16(125.0f));
EXPECT_FP_EQ_ALL_ROUNDING(-5.0f, LIBC_NAMESPACE::cbrtf16(-125.0f));
EXPECT_FP_EQ_ALL_ROUNDING(40.0f, LIBC_NAMESPACE::cbrtf16(0x1.f4p15));
EXPECT_FP_EQ_ALL_ROUNDING(-40.0f, LIBC_NAMESPACE::cbrtf16(-0x1.f4p15));
}
Loading