From fe9b9956f87f1d4eead28a194eabc774e08581b6 Mon Sep 17 00:00:00 2001 From: Trevor Gross Date: Mon, 20 Apr 2026 10:15:12 +0000 Subject: [PATCH] libm: Add `fmaf16` --- .../crates/api-list-common/src/lib.rs | 10 + .../etc/function-definitions.json | 7 + .../compiler-builtins/etc/function-list.txt | 1 + .../libm-test/benches/icount.rs | 1 + .../libm-test/src/generate/case_list.rs | 51 ++++ .../libm-test/src/mpfloat.rs | 2 +- .../libm-test/src/precision.rs | 2 + .../compiler-builtins/libm/src/libm_helper.rs | 1 + .../compiler-builtins/libm/src/math/fma.rs | 18 +- .../compiler-builtins/libm/src/math/fmaf16.rs | 258 ++++++++++++++++++ .../compiler-builtins/libm/src/math/mod.rs | 6 +- .../libm/src/math/support/float_traits.rs | 8 +- .../libm/src/math/support/mod.rs | 8 +- 13 files changed, 357 insertions(+), 16 deletions(-) create mode 100644 library/compiler-builtins/libm/src/math/fmaf16.rs diff --git a/library/compiler-builtins/crates/api-list-common/src/lib.rs b/library/compiler-builtins/crates/api-list-common/src/lib.rs index 3e5868e752bc..f6e5516a76fa 100644 --- a/library/compiler-builtins/crates/api-list-common/src/lib.rs +++ b/library/compiler-builtins/crates/api-list-common/src/lib.rs @@ -1081,6 +1081,16 @@ pub fn defined_in_compiler_builtins(self) -> bool { ], scope: OpScope::LibmPublic, }, + NestedOp { + // `(f16, f16, f16) -> f16` + rust_sig: Signature { + args: &[Ty::F16, Ty::F16, Ty::F16], + returns: &[Ty::F16], + }, + c_sig: None, + fn_list: &["fmaf16"], + scope: OpScope::LibmPublic, + }, NestedOp { // `(f32, f32, f32) -> f32` rust_sig: Signature { diff --git a/library/compiler-builtins/etc/function-definitions.json b/library/compiler-builtins/etc/function-definitions.json index 38d609da3fcf..f59bb7ceebb4 100644 --- a/library/compiler-builtins/etc/function-definitions.json +++ b/library/compiler-builtins/etc/function-definitions.json @@ -365,6 +365,13 @@ ], "type": "f128" }, + "fmaf16": { + "sources": [ + "libm/src/math/fmaf16.rs", + 
"libm/src/math/generic/fma.rs" + ], + "type": "f16" + }, "fmax": { "sources": [ "libm/src/math/fmin_fmax.rs", diff --git a/library/compiler-builtins/etc/function-list.txt b/library/compiler-builtins/etc/function-list.txt index f7a694d10f95..14f75a2c5783 100644 --- a/library/compiler-builtins/etc/function-list.txt +++ b/library/compiler-builtins/etc/function-list.txt @@ -54,6 +54,7 @@ floorf16 fma fmaf fmaf128 +fmaf16 fmax fmaxf fmaxf128 diff --git a/library/compiler-builtins/libm-test/benches/icount.rs b/library/compiler-builtins/libm-test/benches/icount.rs index f67f7b049d30..8e067f2da5e6 100644 --- a/library/compiler-builtins/libm-test/benches/icount.rs +++ b/library/compiler-builtins/libm-test/benches/icount.rs @@ -132,6 +132,7 @@ fn [< icount_bench_ $fn_name >](cases: Vec>) { icount_bench_fma, icount_bench_fmaf, icount_bench_fmaf128, + icount_bench_fmaf16, icount_bench_fmax, icount_bench_fmaxf, icount_bench_fmaxf128, diff --git a/library/compiler-builtins/libm-test/src/generate/case_list.rs b/library/compiler-builtins/libm-test/src/generate/case_list.rs index d3daf86843d1..2275787307d9 100644 --- a/library/compiler-builtins/libm-test/src/generate/case_list.rs +++ b/library/compiler-builtins/libm-test/src/generate/case_list.rs @@ -6,6 +6,8 @@ //! //! This is useful for adding regression tests or expected failures. 
+#[cfg(f16_enabled)] +use libm::hf16; #[cfg(f128_enabled)] use libm::hf128; use libm::{hf32, hf64}; @@ -943,6 +945,55 @@ fn floorf16_cases() -> Vec> { cases![] } +#[cfg(f16_enabled)] +fn fmaf16_cases() -> Vec> { + cases![ + // Subnormal result + ((hf16!("0x1p-11"), hf16!("0x1p-11"), 0.0), None,), + ((hf16!("0x1p-24"), hf16!("0x1p-24"), 0.0), None,), + // Failed during extensive tests + ( + ( + hf16!("-0x1.c4p-12"), + hf16!("0x1.22p-14"), + hf16!("-0x1.f4p-15"), + ), + hf16!("-0x1.f48p-15") + ), + // Examples from https://github.com/llvm/llvm-project/issues/128450 + ( + ( + hf16!("0x1.400p+8"), + hf16!("0x1.008p+7"), + hf16!("0x1.000p-24"), + ), + hf16!("0x1.40cp+15") + ), + ( + (hf16!("0x1.eb8p-12"), hf16!("0x1.9p-11"), hf16!("-0x1p-11"),), + None + ), + // Previous failures during testing + ((-569.0, -4.89, 65470.0), None), + ((-998.0, 0.02596, -998.0), None), + ((6e-8, 6e-8, -6.104e-5), None), + ((6e-8, 65300.0, 9.5e-7), None), + ((-569.0, -4.89, -0.0417), None), + ((6444.0, 0.003443, 0.003443), None), + ((6e-8, 6e-8, 6e-8), None), + ((6e-8, -1.0, 6e-8), None), + ((1.001, 65500.0, 6e-8), None), + ((1.001, 65500.0, 65500.0), None), + ((1.002, 65470.0, 0.0), None), + ((1.002, 65470.0, 65500.0), None), + ((0.0002216, -8.464e-5, 4.4e-5), None), + ((-56700.0, -2.082, -61120.0), None), + ((-475.3, -475.3, 60450.0), None), + ((-61120.0, -3.969, 20320.0), None), + ((6e-8, -6e-8, 0.0), None), + ] +} + fn fmaf_cases() -> Vec> { cases![ // Known rounding error for some implementations (notably MinGW) diff --git a/library/compiler-builtins/libm-test/src/mpfloat.rs b/library/compiler-builtins/libm-test/src/mpfloat.rs index c4f1ca193e58..69c6590d3763 100644 --- a/library/compiler-builtins/libm-test/src/mpfloat.rs +++ b/library/compiler-builtins/libm-test/src/mpfloat.rs @@ -357,7 +357,7 @@ fn run(this: &mut Self::MpTy, input: Self::RustArgs) -> Self::RustRet { expm1 | expm1f => exp_m1, fabs | fabsf => abs, fdim | fdimf | fdimf16 | fdimf128 => positive_diff, - fma | fmaf | 
fmaf128 => mul_add, + fmaf16 | fma | fmaf | fmaf128 => mul_add, fmax | fmaxf | fmaxf16 | fmaxf128 | fmaximum_num | fmaximum_numf | fmaximum_numf16 | fmaximum_numf128 => max, fmin | fminf | fminf16 | fminf128 | diff --git a/library/compiler-builtins/libm-test/src/precision.rs b/library/compiler-builtins/libm-test/src/precision.rs index 2034e89c71e4..967274caa964 100644 --- a/library/compiler-builtins/libm-test/src/precision.rs +++ b/library/compiler-builtins/libm-test/src/precision.rs @@ -549,6 +549,8 @@ impl MaybeOverride<(f64, i32)> for SpecialCase {} #[cfg(f128_enabled)] impl MaybeOverride<(f128, i32)> for SpecialCase {} +#[cfg(f16_enabled)] +impl MaybeOverride<(f16, f16, f16)> for SpecialCase {} impl MaybeOverride<(f32, f32, f32)> for SpecialCase {} impl MaybeOverride<(f64, f64, f64)> for SpecialCase {} #[cfg(f128_enabled)] diff --git a/library/compiler-builtins/libm/src/libm_helper.rs b/library/compiler-builtins/libm/src/libm_helper.rs index 1e6e0beb3e3e..f2c4daed561e 100644 --- a/library/compiler-builtins/libm/src/libm_helper.rs +++ b/library/compiler-builtins/libm/src/libm_helper.rs @@ -195,6 +195,7 @@ pub fn $func($($arg: $arg_typ),*) -> ($($ret_typ),*) { (fn fabs(x: f16) -> (f16); => fabsf16); (fn fdim(x: f16, y: f16) -> (f16); => fdimf16); (fn floor(x: f16) -> (f16); => floorf16); + (fn fma(x: f16, y: f16, z: f16) -> (f16); => fmaf16); (fn fmax(x: f16, y: f16) -> (f16); => fmaxf16); (fn fmaximum_num(x: f16, y: f16) -> (f16); => fmaximum_numf16); (fn fmaximumf16(x: f16, y: f16) -> (f16); => fmaximumf16); diff --git a/library/compiler-builtins/libm/src/math/fma.rs b/library/compiler-builtins/libm/src/math/fma.rs index e99d95cb6152..513921d530fa 100644 --- a/library/compiler-builtins/libm/src/math/fma.rs +++ b/library/compiler-builtins/libm/src/math/fma.rs @@ -4,13 +4,7 @@ use super::generic; use crate::support::Round; -// Placeholder so we can have `fmaf16` in the `Float` trait. 
-#[allow(unused)] -#[cfg(f16_enabled)] -#[cfg_attr(assert_no_panic, no_panic::no_panic)] -pub(crate) fn fmaf16(_x: f16, _y: f16, _z: f16) -> f16 { - unimplemented!() -} +/* See `fmaf16.rs` for that implementation */ /// Floating multiply add (f32) /// @@ -78,8 +72,12 @@ macro_rules! cases { (-1.0, -1.0, -1.0, 0.0), // Roundtrip + (<$f>::MAX, 1.0, 0.0, <$f>::MAX), + (<$f>::MAX, <$f>::MAX, 1.0, <$f>::INFINITY), (<$f>::MAX, 1.0, -<$f>::MAX, 0.0), (-<$f>::MAX, 1.0, <$f>::MAX, 0.0), + (<$f>::MIN_POSITIVE_NORMAL, 1.0, -<$f>::MIN_POSITIVE_NORMAL, 0.0), + (-<$f>::MIN_POSITIVE_NORMAL, 1.0, <$f>::MIN_POSITIVE_NORMAL, 0.0), (<$f>::MIN_POSITIVE_SUBNORMAL, 1.0, -<$f>::MIN_POSITIVE_SUBNORMAL, 0.0), (-<$f>::MIN_POSITIVE_SUBNORMAL, 1.0, <$f>::MIN_POSITIVE_SUBNORMAL, 0.0), (<$f>::MAX, 1.0, -<$f>::MAX, 0.0), @@ -110,6 +108,12 @@ fn check(f: fn(F, F, F) -> F, cases: &[(F, F, F, F)]) { } } + #[test] + #[cfg(f16_enabled)] + fn check_f16() { + check::<f16>(super::super::fmaf16, &cases!(f16)); + } + #[test] fn check_f32() { check::<f32>(fmaf, &cases!(f32)); diff --git a/library/compiler-builtins/libm/src/math/fmaf16.rs b/library/compiler-builtins/libm/src/math/fmaf16.rs new file mode 100644 index 000000000000..8d1c5bccf852 --- /dev/null +++ b/library/compiler-builtins/libm/src/math/fmaf16.rs @@ -0,0 +1,258 @@ +/* SPDX-License-Identifier: MIT OR Apache-2.0 + * origin: original implementation, 2026 (TG) */ + +use crate::support::{CastFrom, Float, Int, unbounded_shr_u64}; + +/// We use a U21.43 fixed-point representation when needed. +const FIXED_FRAC_BITS: u32 = 43; + +/// Floating multiply add (f16) +/// +/// Computes `(x*y)+z`, rounded as one ternary operation (i.e. calculated with infinite precision). 
+#[cfg_attr(assert_no_panic, no_panic::no_panic)] +pub fn fmaf16(x: f16, y: f16, z: f16) -> f16 { + let ix = x.to_bits() & !f16::SIGN_MASK; + let iy = y.to_bits() & !f16::SIGN_MASK; + let iz = z.to_bits() & !f16::SIGN_MASK; + + let xneg = x.is_sign_negative(); + let yneg = y.is_sign_negative(); + let zneg = z.is_sign_negative(); + + let mneg = xneg ^ yneg; + + if ix == 0 || ix >= f16::EXP_MASK || iy == 0 || iy >= f16::EXP_MASK { + // `x` or `y` is zero, infinite, or NaN; defer to non-fused operations. + return x * y + z; + } + + if iz == 0 { + // Empty add component means we only need to multiply. + return x * y; + } + + if iz >= f16::EXP_MASK { + // `z` is NaN or infinity, which sets the result. + return z; + } + + let mut xexp = x.ex(); + let mut yexp = y.ex(); + let mut zexp = z.ex(); + + let mut xsig = ix & f16::SIG_MASK; + let mut ysig = iy & f16::SIG_MASK; + let mut zsig = iz & f16::SIG_MASK; + + // If not subnormal, set the implicit bit + if xexp != 0 { + xsig |= f16::IMPLICIT_BIT; + } + if yexp != 0 { + ysig |= f16::IMPLICIT_BIT; + } + if zexp != 0 { + zsig |= f16::IMPLICIT_BIT; + } + + // A biased exponent of 1 (min normal) and 0 (subnormal) have the same real exponent, so + // adjust for this. Bias is now 14 rather than 15. + xexp = xexp.saturating_sub(1); + yexp = yexp.saturating_sub(1); + zexp = zexp.saturating_sub(1); + let adjbias = f16::EXP_BIAS - 1; + + // Exponent after multiplication. Bias doubles to 28. + let mexp = xexp + yexp; + let mbias = adjbias * 2; + + // Exit now if we know the result will overflow. We need to keep one beyond the infinite + // exponent in case the addition rounds down to a finite number. + // + // Note that `EXP_MAX` (i.e. max finite) represents infinity here because our values are + // acting with a bias of 14. 
+ let inf_exp = mbias + f16::EXP_MAX.unsigned(); + if mexp > inf_exp + 1 { + if mneg { + return f16::NEG_INFINITY; + } else { + return f16::INFINITY; + } + } + + // Multiplication moves the explicit 1 from the 11th bit to the 22nd bit. + let m = u32::from(xsig) * u32::from(ysig); + let mut m64 = u64::from(m); + + // The entire dynamic range of an `f16` fits into a `u64`. Shift based on the exponent to + // create a U21.43 fixed-point value. At the maximum exponent, there are five zeros before + // the explicit leading 1 (intentional so this truncates to the final repr). + if let Some(mshift) = mexp.checked_sub(5) { + debug_assert_eq!( + unbounded_shr_u64(m64, 64 - mshift), + 0, + "data shifted out {m} {mshift}" + ); + m64 <<= mshift; + } else { + // The lower few bits here would be on the order of 2^-43, which is too small to show up + // in a result significand. Just squash them to a sticky bit. + let sticky = m64 & 0b11111 != 0; + m64 >>= 5 - mexp; + m64 |= u64::from(sticky); + } + + // Shift z to U21.43 as well. + let zshift = zexp + FIXED_FRAC_BITS - f16::SIG_BITS - adjbias; + let z64 = u64::from(zsig) << zshift; + + let sub = mneg ^ zneg; + + let rneg; + let r64 = if sub { + if m64 > z64 { + rneg = mneg; + m64.wrapping_sub(z64) + } else if m64 == z64 { + rneg = false; + m64.wrapping_sub(z64) + } else { + rneg = zneg; + z64.wrapping_sub(m64) + } + } else { + rneg = mneg; + m64 + z64 + }; + + let sign = if rneg { -1.0 } else { 1.0 }; + f16_from_u21_43(r64).copysign(sign) +} + +/// Turn a U21.43 value into an f16 with positive sign. +fn f16_from_u21_43(mut r64: u64) -> f16 { + let extra_bits = 64 - 16; + let max_finite_lz = 64 - f16::SIG_BITS - extra_bits - 1; // 5 + + // Check for overflow to infinity after addition, return before checking lz. + if r64 & (u64::MAX << (64 - max_finite_lz)) != 0 { + return f16::INFINITY; + } + + // Shift the fixed point to floating point. There are 5 leading zeros before the largest + // finite value's explicit one. 
+ // + // We want `rexp` as one less than the actual value to be stored because it gets added to + // a value with the leading one set. This value and the shift are clamped so subnormals + // don't become normalized. + let exp_max_biased_m1 = f16::EXP_MAX.unsigned() + f16::EXP_BIAS - 1; // 29 + let lz = r64.leading_zeros(); + let rexp = (exp_max_biased_m1 + max_finite_lz).saturating_sub(lz); + let shift = exp_max_biased_m1 - rexp; + r64 <<= shift; + + // Round up if the round bit (one past significand end) is set and either any trailing bit + // or the preceding bit (the result's LSB) is set, i.e. round half to even. + let round_bit = 1u64 << (extra_bits - 1); + let up_mask = ((1u64 << (extra_bits + 1)) - 1) & !round_bit; + let round_up = r64 & round_bit != 0 && r64 & up_mask != 0; + let round_up = u16::from(round_up); + + // Truncate then round. Automatically accounts for subnormals with the unset explicit decimal + // bit, since `rexp` is one less than the actual biased value. + let mut r = (r64 >> extra_bits) as u16; + r += u16::cast_from(rexp) << f16::SIG_BITS; + r += round_up; + + f16::from_bits(r) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_from_fixed() { + // Move 1.xx... floating point to 1.xx... 
fixed point + let shift_to_one = |x: u16| u64::from(x) << (FIXED_FRAC_BITS - f16::SIG_BITS); + let top_sig = f16::IMPLICIT_BIT; + let max_sig = f16::IMPLICIT_BIT | f16::SIG_MASK; + + // Basic values + let one = shift_to_one(top_sig); + let max = shift_to_one(max_sig) << 15; + let inf = shift_to_one(max_sig + 1) << 15; + let min_norm = one >> 14; + let max_sub = shift_to_one(f16::SIG_MASK) >> 14; + let min_sub = shift_to_one(1) >> 14; + + assert_biteq!(f16_from_u21_43(0), 0.0f16); + assert_biteq!(f16_from_u21_43(one), 1.0f16); + assert_biteq!(f16_from_u21_43(max), f16::MAX); + assert_biteq!(f16_from_u21_43(inf), f16::INFINITY); + assert_biteq!(f16_from_u21_43(min_norm), f16::MIN_POSITIVE_NORMAL); + assert_biteq!(f16_from_u21_43(max_sub), f16::from_bits(f16::SIG_MASK)); + assert_biteq!(f16_from_u21_43(min_sub), f16::MIN_POSITIVE_SUBNORMAL); + + // Masks centered around 1 to add a rounding + let mask_r = shift_to_one(0b1) >> 1; // round bit + let mask_rg = shift_to_one(0b11) >> 2; // round + guard + let mask_rgs = shift_to_one(0b111) >> 3; // round + guard + sticky + let mask_rs = shift_to_one(0b101) >> 3; // round + sticky + let mask_rs2 = shift_to_one(0b1000_0001) >> 8; // round + part of sticky + + let signed_shift = |val: u64, shift: i32| { + if shift >= 0 { + val << shift + } else { + val >> -shift + } + }; + + let check_round = |fixed: u64, shift: i32, lsb_set: bool, down: f16, up: f16| { + // Masks that will cause rounding down + let mdown = if lsb_set { &[0][..] } else { &[0, mask_r][..] }; + // Masks that will cause rounding up + let mup = if lsb_set { + &[mask_r, mask_rg, mask_rgs, mask_rs, mask_rs2][..] + } else { + &[mask_rg, mask_rgs, mask_rs, mask_rs2][..] 
+ }; + + for (i, mask) in mdown.iter().enumerate() { + let bits = fixed | signed_shift(*mask, shift); + assert_biteq!(f16_from_u21_43(bits), down, "{bits:#066b} {i}"); + } + + for (i, mask) in mup.iter().enumerate() { + let bits = fixed | signed_shift(*mask, shift); + assert_biteq!(f16_from_u21_43(bits), up, "{bits:#066b} {i}"); + } + }; + + check_round(one, 0, false, 1.0, 1.0f16.next_up()); + check_round(max, 15, true, f16::MAX, f16::INFINITY); + check_round( + min_norm, + -14, + false, + f16::MIN_POSITIVE_NORMAL, + f16::MIN_POSITIVE_NORMAL.next_up(), + ); + check_round( + max_sub, + -14, + true, + f16::MIN_POSITIVE_NORMAL.next_down(), + f16::MIN_POSITIVE_NORMAL, + ); + check_round( + min_sub, + -14, + true, + f16::MIN_POSITIVE_SUBNORMAL, + f16::MIN_POSITIVE_SUBNORMAL.next_up(), + ); + check_round(0, -14, false, 0.0, f16::MIN_POSITIVE_SUBNORMAL); + } +} diff --git a/library/compiler-builtins/libm/src/math/mod.rs b/library/compiler-builtins/libm/src/math/mod.rs index fe9ef580ca67..809d5e5d1809 100644 --- a/library/compiler-builtins/libm/src/math/mod.rs +++ b/library/compiler-builtins/libm/src/math/mod.rs @@ -296,12 +296,15 @@ macro_rules! i { cfg_if! { if #[cfg(f16_enabled)] { + mod fmaf16; + // verify-sorted-start pub use self::ceil::ceilf16; pub use self::copysign::copysignf16; pub use self::fabs::fabsf16; pub use self::fdim::fdimf16; pub use self::floor::floorf16; + pub use self::fmaf16::fmaf16; pub use self::fmin_fmax::{fmaxf16, fminf16}; pub use self::fminimum_fmaximum::{fmaximumf16, fminimumf16}; pub use self::fminimum_fmaximum_num::{fmaximum_numf16, fminimum_numf16}; @@ -316,9 +319,6 @@ macro_rules! 
i { pub use self::sqrt::sqrtf16; pub use self::trunc::truncf16; // verify-sorted-end - - #[allow(unused_imports)] - pub(crate) use self::fma::fmaf16; } } diff --git a/library/compiler-builtins/libm/src/math/support/float_traits.rs b/library/compiler-builtins/libm/src/math/support/float_traits.rs index 78fcb304547b..7bc9734a2099 100644 --- a/library/compiler-builtins/libm/src/math/support/float_traits.rs +++ b/library/compiler-builtins/libm/src/math/support/float_traits.rs @@ -64,13 +64,13 @@ pub trait Float: const MIN_POSITIVE_NORMAL: Self; const MIN_POSITIVE_SUBNORMAL: Self; - /// The bitwidth of the float type + /// The bitwidth of the float type. const BITS: u32; - /// The bitwidth of the significand + /// The bitwidth of the significand (does not include the implicit bit). const SIG_BITS: u32; - /// The bitwidth of the exponent + /// The bitwidth of the exponent. const EXP_BITS: u32 = Self::BITS - Self::SIG_BITS - 1; /// The saturated (maximum bitpattern) value of the exponent, i.e. the infinite @@ -79,7 +79,7 @@ pub trait Float: /// This shifted fully right, use `EXP_MASK` for the shifted value. const EXP_SAT: u32 = (1 << Self::EXP_BITS) - 1; - /// The exponent bias value + /// The exponent bias value. const EXP_BIAS: u32 = Self::EXP_SAT >> 1; /// Maximum unbiased exponent value. diff --git a/library/compiler-builtins/libm/src/math/support/mod.rs b/library/compiler-builtins/libm/src/math/support/mod.rs index 8bca2e60c7df..61999a1f9405 100644 --- a/library/compiler-builtins/libm/src/math/support/mod.rs +++ b/library/compiler-builtins/libm/src/math/support/mod.rs @@ -69,9 +69,15 @@ pub unsafe fn unchecked_div_isize(x: isize, y: isize) -> isize { } } -// FIXME(msrv): `div_ceil` is stablein 1.73. +// FIXME(msrv): `div_ceil` is stable in 1.73. 
pub fn div_ceil_u32(a: u32, b: u32) -> u32 { let d = a / b; let r = a % b; if r > 0 { d + 1 } else { d } } + +// FIXME(msrv): unbounded shifts are stable in 1.87 +#[allow(unused)] +pub fn unbounded_shr_u64(x: u64, shift: u32) -> u64 { + x.checked_shr(shift).unwrap_or(0) +}