diff --git a/lib/compiler_rt/udivmod.zig b/lib/compiler_rt/udivmod.zig index 44165c0a62..fdb338c21a 100644 --- a/lib/compiler_rt/udivmod.zig +++ b/lib/compiler_rt/udivmod.zig @@ -182,7 +182,7 @@ fn divwide(comptime T: type, _u1: T, _u0: T, v: T, r: *T) T { pub fn udivmod(comptime T: type, a_: T, b_: T, maybe_rem: ?*T) T { @setRuntimeSafety(compiler_rt.test_safety); const HalfT = HalveInt(T, false).HalfT; - const SignedT = std.meta.Int(.signed, @bitSizeOf(T)); + const half_bits = @bitSizeOf(HalfT); if (b_ > a_) { if (maybe_rem) |rem| { @@ -214,26 +214,85 @@ pub fn udivmod(comptime T: type, a_: T, b_: T, maybe_rem: ?*T) T { return @bitCast(q); } - // 0 <= shift <= 63 - const shift: Log2Int(T) = @clz(b[hi]) - @clz(a[hi]); - var af: T = @bitCast(a); - var bf = @as(T, @bitCast(b)) << shift; - q = @bitCast(@as(T, 0)); + // Large-divisor case: b[hi] != 0, so the quotient fits in one HalfT word. + // + // Trial quotient via divwide (Knuth Vol 2, Section 4.3.1): + // Normalize the divisor so its high half has the MSB set, then use divwide + // on the top bits to get a trial quotient that is at most 1 too large. + // This replaces the O(shift) bit-by-bit loop with O(1) operations. + const s: Log2Int(HalfT) = @intCast(@clz(b[hi])); - for (0..shift + 1) |_| { - q[lo] <<= 1; - // Branchless version of: - // if (af >= bf) { - // af -= bf; - // q[lo] |= 1; - // } - const s = @as(SignedT, @bitCast(bf -% af -% 1)) >> (@bitSizeOf(T) - 1); - q[lo] |= @intCast(s & 1); - af -= bf & @as(T, @bitCast(s)); - bf >>= 1; + if (s == 0) { + // b[hi] already has its MSB set, so b >= 2^(T_bits - 1). Since a >= b + // (we passed the b_ > a_ check), a >= 2^(T_bits - 1) too, meaning + // a[hi] also has its MSB set. Therefore a / b < 2, and the quotient + // is exactly 1. + q = @bitCast(@as(T, 0)); + q[lo] = 1; + if (maybe_rem) |rem| { + rem.* = a_ - b_; + } + return @bitCast(q); } + + // Normalize b: shift left by s so bn_hi has its MSB set. + const sr: Log2Int(HalfT) = @intCast(half_bits - @as( + std.math.IntFittingRange(0, half_bits), + @intCast(s), + )); + const bn_hi: HalfT = (b[hi] << s) | (b[lo] >> sr); + + // Trial numerator: the top (half_bits + s) bits of (a << s), as [a2:a1]. + // a2 < bn_hi is guaranteed since a2 < 2^s and bn_hi >= 2^(half_bits - 1). + const a2: HalfT = a[hi] >> sr; + const a1: HalfT = (a[hi] << s) | (a[lo] >> sr); + + // Trial quotient via divwide: q_hat = floor([a2:a1] / bn_hi). + // By Knuth's theorem (normalized divisor), q <= q_hat <= q + 1. + var r_tmp: HalfT = undefined; + var q_hat: HalfT = divwide(HalfT, a2, a1, bn_hi, &r_tmp); + + // Verify: q_hat * b must not exceed a. + // Compute the product using HalfT * HalfT -> T widening multiplications, + // which are native single-instruction ops when HalfT fits in a register + // (e.g. u64 * u64 -> u128 via mulq on x86_64, mul on aarch64). + // product = q_hat * [b[hi]:b[lo]] = [p_top : p_mid : p_lo] (3 half-words) + const prod_lo: T = @as(T, q_hat) * @as(T, b[lo]); + const prod_hi: T = @as(T, q_hat) * @as(T, b[hi]); + + const prod_lo_parts: [2]HalfT = @bitCast(prod_lo); + const prod_hi_parts: [2]HalfT = @bitCast(prod_hi); + + const mid_add = @addWithOverflow(prod_hi_parts[lo], prod_lo_parts[hi]); + var p_mid: HalfT = mid_add[0]; + const p_top: HalfT = prod_hi_parts[hi] +% @as(HalfT, mid_add[1]); + var p_lo: HalfT = prod_lo_parts[lo]; + + // If product > a, decrement q_hat (at most once, guaranteed by Knuth). + if (p_top > 0 or p_mid > a[hi] or (p_mid == a[hi] and p_lo > a[lo])) { + q_hat -= 1; + // Subtract b from the product for correct remainder computation. + // After correction, (q_hat * b) fits in T bits, so borrows into + // p_top cancel it to zero -- we only need [p_mid:p_lo]. + const sub_lo = @subWithOverflow(p_lo, b[lo]); + p_lo = sub_lo[0]; + const sub_mid = @subWithOverflow(p_mid, b[hi]); + const sub_mid2 = @subWithOverflow(sub_mid[0], @as(HalfT, sub_lo[1])); + p_mid = sub_mid2[0]; + } + + q = @bitCast(@as(T, 0)); + q[lo] = q_hat; + if (maybe_rem) |rem| { - rem.* = @bitCast(af); + // remainder = a - q_hat * b = [a[hi]:a[lo]] - [p_mid:p_lo] + // This subtraction is non-negative since q_hat <= true quotient. + const rem_lo = @subWithOverflow(a[lo], p_lo); + r[lo] = rem_lo[0]; + const rem_hi = @subWithOverflow(a[hi], p_mid); + const rem_hi2 = @subWithOverflow(rem_hi[0], @as(HalfT, rem_lo[1])); + r[hi] = rem_hi2[0]; + rem.* = @bitCast(r); } return @bitCast(q); }