libm: Add fmaf16

This commit is contained in:
Trevor Gross
2026-04-20 10:15:12 +00:00
parent f8dda89274
commit fe9b9956f8
13 changed files with 357 additions and 16 deletions
@@ -1081,6 +1081,16 @@ pub fn defined_in_compiler_builtins(self) -> bool {
],
scope: OpScope::LibmPublic,
},
NestedOp {
// `(f16, f16, f16) -> f16`
rust_sig: Signature {
args: &[Ty::F16, Ty::F16, Ty::F16],
returns: &[Ty::F16],
},
c_sig: None,
fn_list: &["fmaf16"],
scope: OpScope::LibmPublic,
},
NestedOp {
// `(f32, f32, f32) -> f32`
rust_sig: Signature {
@@ -365,6 +365,13 @@
],
"type": "f128"
},
"fmaf16": {
"sources": [
"libm/src/math/fmaf16.rs",
"libm/src/math/generic/fma.rs"
],
"type": "f16"
},
"fmax": {
"sources": [
"libm/src/math/fmin_fmax.rs",
@@ -54,6 +54,7 @@ floorf16
fma
fmaf
fmaf128
fmaf16
fmax
fmaxf
fmaxf128
@@ -132,6 +132,7 @@ fn [< icount_bench_ $fn_name >](cases: Vec<OpRustArgs<op::$fn_name::Routine>>) {
icount_bench_fma,
icount_bench_fmaf,
icount_bench_fmaf128,
icount_bench_fmaf16,
icount_bench_fmax,
icount_bench_fmaxf,
icount_bench_fmaxf128,
@@ -6,6 +6,8 @@
//!
//! This is useful for adding regression tests or expected failures.
#[cfg(f16_enabled)]
use libm::hf16;
#[cfg(f128_enabled)]
use libm::hf128;
use libm::{hf32, hf64};
@@ -943,6 +945,55 @@ fn floorf16_cases() -> Vec<TestCase<op::floorf16::Routine>> {
cases![]
}
#[cfg(f16_enabled)]
fn fmaf16_cases() -> Vec<TestCase<op::fmaf16::Routine>> {
cases![
// Subnormal result
((hf16!("0x1p-11"), hf16!("0x1p-11"), 0.0), None,),
((hf16!("0x1p-24"), hf16!("0x1p-24"), 0.0), None,),
// Failed during extensive tests
(
(
hf16!("-0x1.c4p-12"),
hf16!("0x1.22p-14"),
hf16!("-0x1.f4p-15"),
),
hf16!("-0x1.f48p-15")
),
// Examples from https://github.com/llvm/llvm-project/issues/128450
(
(
hf16!("0x1.400p+8"),
hf16!("0x1.008p+7"),
hf16!("0x1.000p-24"),
),
hf16!("0x1.40cp+15")
),
(
(hf16!("0x1.eb8p-12"), hf16!("0x1.9p-11"), hf16!("-0x1p-11"),),
None
),
// Previous failures during testing
((-569.0, -4.89, 65470.0), None),
((-998.0, 0.02596, -998.0), None),
((6e-8, 6e-8, -6.104e-5), None),
((6e-8, 65300.0, 9.5e-7), None),
((-569.0, -4.89, -0.0417), None),
((6444.0, 0.003443, 0.003443), None),
((6e-8, 6e-8, 6e-8), None),
((6e-8, -1.0, 6e-8), None),
((1.001, 65500.0, 6e-8), None),
((1.001, 65500.0, 65500.0), None),
((1.002, 65470.0, 0.0), None),
((1.002, 65470.0, 65500.0), None),
((0.0002216, -8.464e-5, 4.4e-5), None),
((-56700.0, -2.082, -61120.0), None),
((-475.3, -475.3, 60450.0), None),
((-61120.0, -3.969, 20320.0), None),
((6e-8, -6e-8, 0.0), None),
]
}
fn fmaf_cases() -> Vec<TestCase<op::fmaf::Routine>> {
cases![
// Known rounding error for some implementations (notably MinGW)
@@ -357,7 +357,7 @@ fn run(this: &mut Self::MpTy, input: Self::RustArgs) -> Self::RustRet {
expm1 | expm1f => exp_m1,
fabs | fabsf => abs,
fdim | fdimf | fdimf16 | fdimf128 => positive_diff,
fma | fmaf | fmaf128 => mul_add,
fmaf16 | fma | fmaf | fmaf128 => mul_add,
fmax | fmaxf | fmaxf16 | fmaxf128 |
fmaximum_num | fmaximum_numf | fmaximum_numf16 | fmaximum_numf128 => max,
fmin | fminf | fminf16 | fminf128 |
@@ -549,6 +549,8 @@ impl MaybeOverride<(f64, i32)> for SpecialCase {}
#[cfg(f128_enabled)]
impl MaybeOverride<(f128, i32)> for SpecialCase {}
#[cfg(f16_enabled)]
impl MaybeOverride<(f16, f16, f16)> for SpecialCase {}
impl MaybeOverride<(f32, f32, f32)> for SpecialCase {}
impl MaybeOverride<(f64, f64, f64)> for SpecialCase {}
#[cfg(f128_enabled)]
@@ -195,6 +195,7 @@ pub fn $func($($arg: $arg_typ),*) -> ($($ret_typ),*) {
(fn fabs(x: f16) -> (f16); => fabsf16);
(fn fdim(x: f16, y: f16) -> (f16); => fdimf16);
(fn floor(x: f16) -> (f16); => floorf16);
(fn fma(x: f16, y: f16, z: f16) -> (f16); => fmaf16);
(fn fmax(x: f16, y: f16) -> (f16); => fmaxf16);
(fn fmaximum_num(x: f16, y: f16) -> (f16); => fmaximum_numf16);
(fn fmaximumf16(x: f16, y: f16) -> (f16); => fmaximumf16);
+11 -7
View File
@@ -4,13 +4,7 @@
use super::generic;
use crate::support::Round;
// Placeholder so we can have `fmaf16` in the `Float` trait.
#[allow(unused)]
#[cfg(f16_enabled)]
#[cfg_attr(assert_no_panic, no_panic::no_panic)]
pub(crate) fn fmaf16(_x: f16, _y: f16, _z: f16) -> f16 {
unimplemented!()
}
/* See `fmaf16.rs` for that implementation */
/// Floating multiply add (f32)
///
@@ -78,8 +72,12 @@ macro_rules! cases {
(-1.0, -1.0, -1.0, 0.0),
// Roundtrip
(<$f>::MAX, 1.0, 0.0, <$f>::MAX),
(<$f>::MAX, <$f>::MAX, 1.0, <$f>::INFINITY),
(<$f>::MAX, 1.0, -<$f>::MAX, 0.0),
(-<$f>::MAX, 1.0, <$f>::MAX, 0.0),
(<$f>::MIN_POSITIVE_NORMAL, 1.0, -<$f>::MIN_POSITIVE_NORMAL, 0.0),
(-<$f>::MIN_POSITIVE_NORMAL, 1.0, <$f>::MIN_POSITIVE_NORMAL, 0.0),
(<$f>::MIN_POSITIVE_SUBNORMAL, 1.0, -<$f>::MIN_POSITIVE_SUBNORMAL, 0.0),
(-<$f>::MIN_POSITIVE_SUBNORMAL, 1.0, <$f>::MIN_POSITIVE_SUBNORMAL, 0.0),
(<$f>::MAX, 1.0, -<$f>::MAX, 0.0),
@@ -110,6 +108,12 @@ fn check<F: Float>(f: fn(F, F, F) -> F, cases: &[(F, F, F, F)]) {
}
}
#[test]
#[cfg(f16_enabled)]
fn check_f16() {
check::<f16>(super::super::fmaf16, &cases!(f16));
}
#[test]
fn check_f32() {
check::<f32>(fmaf, &cases!(f32));
@@ -0,0 +1,258 @@
/* SPDX-License-Identifier: MIT OR Apache-2.0
* origin: original implementation, 2026 (TG) */
use crate::support::{CastFrom, Float, Int, unbounded_shr_u64};
/// We use a a U21.43 fixed-point representation when needed.
const FIXED_FRAC_BITS: u32 = 43;
/// Floating multiply add (f16)
///
/// Computes `(x*y)+z`, rounded as one ternary operation (i.e. calculated with infinite precision).
#[cfg_attr(assert_no_panic, no_panic::no_panic)]
pub fn fmaf16(x: f16, y: f16, z: f16) -> f16 {
let ix = x.to_bits() & !f16::SIGN_MASK;
let iy = y.to_bits() & !f16::SIGN_MASK;
let iz = z.to_bits() & !f16::SIGN_MASK;
let xneg = x.is_sign_negative();
let yneg = y.is_sign_negative();
let zneg = z.is_sign_negative();
let mneg = xneg ^ yneg;
if ix == 0 || ix >= f16::EXP_MASK || iy == 0 || iy >= f16::EXP_MASK {
// Value will overflow, defer to non-fused operations.
return x * y + z;
}
if iz == 0 {
// Empty add component means we only need to multiply.
return x * y;
}
if iz >= f16::EXP_MASK {
// `z` is NaN or infinity, which sets the result.
return z;
}
let mut xexp = x.ex();
let mut yexp = y.ex();
let mut zexp = z.ex();
let mut xsig = ix & f16::SIG_MASK;
let mut ysig = iy & f16::SIG_MASK;
let mut zsig = iz & f16::SIG_MASK;
// If not subnormal, set the implicit bit
if xexp != 0 {
xsig |= f16::IMPLICIT_BIT;
}
if yexp != 0 {
ysig |= f16::IMPLICIT_BIT;
}
if zexp != 0 {
zsig |= f16::IMPLICIT_BIT;
}
// A biased exponent of 1 (min normal) and 0 (subnormal) have the same real exponent, so
// adjust for this. Bias is now 14 rather than 15.
xexp = xexp.saturating_sub(1);
yexp = yexp.saturating_sub(1);
zexp = zexp.saturating_sub(1);
let adjbias = f16::EXP_BIAS - 1;
// Exponent after multiplication. Bias doubles to 28.
let mexp = xexp + yexp;
let mbias = adjbias * 2;
// Exit now if we know the result will overflow. We need to keep one beyond the infinite
// exponent in case the addition rounds down to a finite number.
//
// Note that `EXP_MAX` (i.e. max finite) represents infinity here because our values are
// acting with a bias of 14.
let inf_exp = mbias + f16::EXP_MAX.unsigned();
if mexp > inf_exp + 1 {
if mneg {
return f16::NEG_INFINITY;
} else {
return f16::INFINITY;
}
}
// Multiplication moves the explicit 1 from the 11th bit to the 22nd bit.
let m = u32::from(xsig) * u32::from(ysig);
let mut m64 = u64::from(m);
// The entire dynamic range of an `f16` fits into a `u64`. Shift based on the exponent to
// create a U21.43 fixed-point value. At the maximum exponent, there are five zeros before
// the explicit leading 1 (intentional so this truncates to the final repr).
if let Some(mshift) = mexp.checked_sub(5) {
debug_assert_eq!(
unbounded_shr_u64(m64, 64 - mshift),
0,
"data shifted out {m} {mshift}"
);
m64 <<= mshift;
} else {
// The lower few bits here would be on the order of 2^-43, which is too small to show up
// in a result significand. Just squash them to a sticky bit.
let sticky = m64 & 0b11111 != 0;
m64 >>= 5 - mexp;
m64 |= u64::from(sticky);
}
// Shift z to U21.43 as well.
let zshift = zexp + FIXED_FRAC_BITS - f16::SIG_BITS - adjbias;
let z64 = u64::from(zsig) << zshift;
let sub = mneg ^ zneg;
let rneg;
let r64 = if sub {
if m64 > z64 {
rneg = mneg;
m64.wrapping_sub(z64)
} else if m64 == z64 {
rneg = false;
m64.wrapping_sub(z64)
} else {
rneg = zneg;
z64.wrapping_sub(m64)
}
} else {
rneg = mneg;
m64 + z64
};
let sign = if rneg { -1.0 } else { 1.0 };
f16_from_u21_43(r64).copysign(sign)
}
/// Turn a U21.43 value into an f16 with positive sign.
fn f16_from_u21_43(mut r64: u64) -> f16 {
let extra_bits = 64 - 16;
let max_finite_lz = 64 - f16::SIG_BITS - extra_bits - 1; // 5
// Check for overflow to infinity after addition, return before checking lz.
if r64 & (u64::MAX << (64 - max_finite_lz)) != 0 {
return f16::INFINITY;
}
// Shift the fixed point to floating point. There are 5 leading zeros before the largest
// finite value's explicit one.
//
// We want `rexp` as one less than the actual value to be stored because it gets added to
// a value with the leading one set. This value and the shift are clamped so subnormals
// don't become normalized.
let exp_max_biased_m1 = f16::EXP_MAX.unsigned() + f16::EXP_BIAS - 1; // 29
let lz = r64.leading_zeros();
let rexp = (exp_max_biased_m1 + max_finite_lz).saturating_sub(lz);
let shift = exp_max_biased_m1 - rexp;
r64 <<= shift;
// Round up if the round bit (one past significand end) is set and any trailing bit is set,
// or if the preceding bit is set.
let round_bit = 1u64 << (extra_bits - 1);
let up_mask = ((1u64 << (extra_bits + 1)) - 1) & !round_bit;
let round_up = r64 & round_bit != 0 && r64 & up_mask != 0;
let round_up = u16::from(round_up);
// Truncate then round. Automatically accounts for subnormals with the unset explicit decimal
// bit, since `rexp` is one less than the actual biased value.
let mut r = (r64 >> extra_bits) as u16;
r += u16::cast_from(rexp) << f16::SIG_BITS;
r += round_up;
f16::from_bits(r)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_from_fixed() {
// Move 1.xx... floating point to 1.xx... fixed point
let shift_to_one = |x: u16| u64::from(x) << (FIXED_FRAC_BITS - f16::SIG_BITS);
let top_sig = f16::IMPLICIT_BIT;
let max_sig = f16::IMPLICIT_BIT | f16::SIG_MASK;
// Basic values
let one = shift_to_one(top_sig);
let max = shift_to_one(max_sig) << 15;
let inf = shift_to_one(max_sig + 1) << 15;
let min_norm = one >> 14;
let max_sub = shift_to_one(f16::SIG_MASK) >> 14;
let min_sub = shift_to_one(1) >> 14;
assert_biteq!(f16_from_u21_43(0), 0.0f16);
assert_biteq!(f16_from_u21_43(one), 1.0f16);
assert_biteq!(f16_from_u21_43(max), f16::MAX);
assert_biteq!(f16_from_u21_43(inf), f16::INFINITY);
assert_biteq!(f16_from_u21_43(min_norm), f16::MIN_POSITIVE_NORMAL);
assert_biteq!(f16_from_u21_43(max_sub), f16::from_bits(f16::SIG_MASK));
assert_biteq!(f16_from_u21_43(min_sub), f16::MIN_POSITIVE_SUBNORMAL);
// Masks centered around 1 to add a rounding
let mask_r = shift_to_one(0b1) >> 1; // round bit
let mask_rg = shift_to_one(0b11) >> 2; // round + guard
let mask_rgs = shift_to_one(0b111) >> 3; // round + guard + sticky
let mask_rs = shift_to_one(0b101) >> 3; // round + sticky
let mask_rs2 = shift_to_one(0b1000_0001) >> 8; // round + part of sticky
let signed_shift = |val: u64, shift: i32| {
if shift >= 0 {
val << shift
} else {
val >> -shift
}
};
let check_round = |fixed: u64, shift: i32, lsb_set: bool, down: f16, up: f16| {
// Masks that will cause rounding down
let mdown = if lsb_set { &[0][..] } else { &[0, mask_r][..] };
// Masks that will cause rounding up
let mup = if lsb_set {
&[mask_r, mask_rg, mask_rgs, mask_rs, mask_rs2][..]
} else {
&[mask_rg, mask_rgs, mask_rs, mask_rs2][..]
};
for (i, mask) in mdown.iter().enumerate() {
let bits = fixed | signed_shift(*mask, shift);
assert_biteq!(f16_from_u21_43(bits), down, "{bits:#066b} {i}");
}
for (i, mask) in mup.iter().enumerate() {
let bits = fixed | signed_shift(*mask, shift);
assert_biteq!(f16_from_u21_43(bits), up, "{bits:#066b} {i}");
}
};
check_round(one, 0, false, 1.0, 1.0f16.next_up());
check_round(max, 15, true, f16::MAX, f16::INFINITY);
check_round(
min_norm,
-14,
false,
f16::MIN_POSITIVE_NORMAL,
f16::MIN_POSITIVE_NORMAL.next_up(),
);
check_round(
max_sub,
-14,
true,
f16::MIN_POSITIVE_NORMAL.next_down(),
f16::MIN_POSITIVE_NORMAL,
);
check_round(
min_sub,
-14,
true,
f16::MIN_POSITIVE_SUBNORMAL,
f16::MIN_POSITIVE_SUBNORMAL.next_up(),
);
check_round(0, -14, false, 0.0, f16::MIN_POSITIVE_SUBNORMAL);
}
}
@@ -296,12 +296,15 @@ macro_rules! i {
cfg_if! {
if #[cfg(f16_enabled)] {
mod fmaf16;
// verify-sorted-start
pub use self::ceil::ceilf16;
pub use self::copysign::copysignf16;
pub use self::fabs::fabsf16;
pub use self::fdim::fdimf16;
pub use self::floor::floorf16;
pub use self::fmaf16::fmaf16;
pub use self::fmin_fmax::{fmaxf16, fminf16};
pub use self::fminimum_fmaximum::{fmaximumf16, fminimumf16};
pub use self::fminimum_fmaximum_num::{fmaximum_numf16, fminimum_numf16};
@@ -316,9 +319,6 @@ macro_rules! i {
pub use self::sqrt::sqrtf16;
pub use self::trunc::truncf16;
// verify-sorted-end
#[allow(unused_imports)]
pub(crate) use self::fma::fmaf16;
}
}
@@ -64,13 +64,13 @@ pub trait Float:
const MIN_POSITIVE_NORMAL: Self;
const MIN_POSITIVE_SUBNORMAL: Self;
/// The bitwidth of the float type
/// The bitwidth of the float type.
const BITS: u32;
/// The bitwidth of the significand
/// The bitwidth of the significand (does not include the implicit bit).
const SIG_BITS: u32;
/// The bitwidth of the exponent
/// The bitwidth of the exponent.
const EXP_BITS: u32 = Self::BITS - Self::SIG_BITS - 1;
/// The saturated (maximum bitpattern) value of the exponent, i.e. the infinite
@@ -79,7 +79,7 @@ pub trait Float:
/// This shifted fully right, use `EXP_MASK` for the shifted value.
const EXP_SAT: u32 = (1 << Self::EXP_BITS) - 1;
/// The exponent bias value
/// The exponent bias value.
const EXP_BIAS: u32 = Self::EXP_SAT >> 1;
/// Maximum unbiased exponent value.
@@ -69,9 +69,15 @@ pub unsafe fn unchecked_div_isize(x: isize, y: isize) -> isize {
}
}
// FIXME(msrv): `div_ceil` is stablein 1.73.
// FIXME(msrv): `div_ceil` is stable in 1.73.
pub fn div_ceil_u32(a: u32, b: u32) -> u32 {
let d = a / b;
let r = a % b;
if r > 0 { d + 1 } else { d }
}
// FIXME(msrv): unbounded shifts are stable in 1.87
#[allow(unused)]
pub fn unbounded_shr_u64(x: u64, shift: u32) -> u64 {
x.checked_shr(shift).unwrap_or(0)
}