stage2-wasm: sat ops

This commit is contained in:
Pavel Verigo
2026-04-08 03:47:24 +02:00
parent aa7874657b
commit f2a842db5c
2 changed files with 121 additions and 121 deletions
+21 -58
View File
@@ -3484,20 +3484,20 @@ fn intAddSat(cg: *CodeGen, int_ty: IntType, lhs: WValue, rhs: WValue) InnerError
defer rhs_is_neg.free(cg);
const min_val = try cg.intMinValue(int_ty);
try cg.emitWValue(min_val);
try cg.emitWValue(max_val);
try cg.lowerToStack(min_val);
try cg.lowerToStack(max_val);
try cg.emitWValue(rhs_is_neg);
try cg.addTag(.select);
try cg.emitWValue(op_val);
try cg.lowerToStack(op_val);
const overflow_cmp = try cg.intCmp(int_ty, .lt, op_val, lhs);
const is_overflow = try cg.intCmp(.u32, .neq, rhs_is_neg, overflow_cmp);
try cg.emitWValue(is_overflow);
try cg.addTag(.select);
return .stack;
} else {
try cg.emitWValue(max_val);
try cg.emitWValue(op_val);
try cg.lowerToStack(max_val);
try cg.lowerToStack(op_val);
const is_overflow = try cg.intCmp(int_ty, .lt, op_val, lhs);
try cg.emitWValue(is_overflow);
@@ -3518,12 +3518,12 @@ fn intSubSat(cg: *CodeGen, int_ty: IntType, lhs: WValue, rhs: WValue) InnerError
const max_val = try cg.intMaxValue(int_ty);
const min_val = try cg.intMinValue(int_ty);
try cg.emitWValue(max_val);
try cg.emitWValue(min_val);
try cg.lowerToStack(max_val);
try cg.lowerToStack(min_val);
try cg.emitWValue(rhs_is_neg);
try cg.addTag(.select);
try cg.emitWValue(op_val);
try cg.lowerToStack(op_val);
const overflow_cmp = try cg.intCmp(int_ty, .gt, op_val, lhs);
const is_overflow = try cg.intCmp(.u32, .neq, rhs_is_neg, overflow_cmp);
try cg.emitWValue(is_overflow);
@@ -3532,8 +3532,8 @@ fn intSubSat(cg: *CodeGen, int_ty: IntType, lhs: WValue, rhs: WValue) InnerError
} else {
const zero = try cg.intZeroValue(int_ty);
try cg.emitWValue(zero);
try cg.emitWValue(op_val);
try cg.lowerToStack(zero);
try cg.lowerToStack(op_val);
const is_overflow = try cg.intCmp(int_ty, .lt, lhs, rhs);
try cg.emitWValue(is_overflow);
try cg.addTag(.select);
@@ -3542,43 +3542,6 @@ fn intSubSat(cg: *CodeGen, int_ty: IntType, lhs: WValue, rhs: WValue) InnerError
}
fn intMulSat(cg: *CodeGen, int_ty: IntType, lhs: WValue, rhs: WValue) InnerError!WValue {
// Remove when > 128 int ops will be implemented in backend
if (int_ty.bits == 128) {
if (!int_ty.is_signed) {
return cg.fail("TODO: mul_sat for unsigned 128-bit integers", .{});
}
const overflow_ret = try cg.allocStack(Type.i32);
const ret = try cg.callIntrinsic(
.__muloti4,
&[_]InternPool.Index{ .i128_type, .i128_type, .usize_type },
Type.i128,
&.{ lhs, rhs, overflow_ret },
);
try cg.lowerToStack(ret);
const xor = try cg.intXor(int_ty, lhs, rhs);
const sign_v = try cg.intShr(int_ty, xor, .{ .imm32 = 127 });
// xor ~@as(u127, 0)
try cg.emitWValue(sign_v);
const lsb = try cg.load(sign_v, Type.u64, 0);
_ = try cg.intXor(.u64, lsb, .{ .imm64 = ~@as(u64, 0) });
try cg.store(.stack, .stack, Type.u64, sign_v.offset());
try cg.emitWValue(sign_v);
const msb = try cg.load(sign_v, Type.u64, 8);
_ = try cg.intXor(.u64, msb, .{ .imm64 = ~@as(u64, 0) >> 1 });
try cg.store(.stack, .stack, Type.u64, sign_v.offset() + 8);
try cg.lowerToStack(sign_v);
_ = try cg.load(overflow_ret, Type.i32, 0);
try cg.addTag(.i32_eqz);
try cg.addTag(.select);
return .stack;
}
const ext_ty: IntType = .{ .is_signed = int_ty.is_signed, .bits = int_ty.bits * 2 };
const lhs_ext = try cg.intCast(ext_ty, int_ty, lhs);
@@ -3594,10 +3557,10 @@ fn intMulSat(cg: *CodeGen, int_ty: IntType, lhs: WValue, rhs: WValue) InnerError
if (int_ty.is_signed) {
const min_val = try cg.intMinValue(int_ty);
try cg.emitWValue(min_val);
try cg.lowerToStack(min_val);
try cg.emitWValue(max_val);
try cg.emitWValue(op_val);
try cg.lowerToStack(max_val);
try cg.lowerToStack(op_val);
const max_ext = try cg.intCast(ext_ty, int_ty, max_val);
const ov_pos = try cg.intCmp(ext_ty, .lt, max_ext, mul_ext);
try cg.emitWValue(ov_pos);
@@ -3605,12 +3568,12 @@ fn intMulSat(cg: *CodeGen, int_ty: IntType, lhs: WValue, rhs: WValue) InnerError
const min_ext = try cg.intCast(ext_ty, int_ty, min_val);
const ov_neg = try cg.intCmp(ext_ty, .gt, min_ext, mul_ext);
try cg.emitWValue(ov_neg);
try cg.lowerToStack(ov_neg);
try cg.addTag(.select);
return .stack;
} else {
try cg.emitWValue(max_val);
try cg.emitWValue(op_val);
try cg.lowerToStack(max_val);
try cg.lowerToStack(op_val);
const max_ext = try cg.intCast(ext_ty, int_ty, max_val);
const is_overflow = try cg.intCmp(ext_ty, .lt, max_ext, mul_ext);
try cg.emitWValue(is_overflow);
@@ -3633,20 +3596,20 @@ fn intShlSat(cg: *CodeGen, int_ty: IntType, lhs: WValue, rhs: WValue) InnerError
const zero = try cg.intZeroValue(int_ty);
const min_val = try cg.intMinValue(int_ty);
try cg.emitWValue(min_val);
try cg.emitWValue(max_val);
try cg.lowerToStack(min_val);
try cg.lowerToStack(max_val);
const lhs_is_neg = try cg.intCmp(int_ty, .lt, lhs, zero);
try cg.emitWValue(lhs_is_neg);
try cg.addTag(.select);
try cg.emitWValue(op_val);
try cg.lowerToStack(op_val);
const is_overflow = try cg.intCmp(int_ty, .neq, check_val, lhs);
try cg.emitWValue(is_overflow);
try cg.addTag(.select);
return .stack;
} else {
try cg.emitWValue(max_val);
try cg.emitWValue(op_val);
try cg.lowerToStack(max_val);
try cg.lowerToStack(op_val);
const is_overflow = try cg.intCmp(int_ty, .neq, check_val, lhs);
try cg.emitWValue(is_overflow);
try cg.addTag(.select);
+100 -63
View File
@@ -4,6 +4,14 @@ const minInt = std.math.minInt;
const maxInt = std.math.maxInt;
const expect = std.testing.expect;
fn testSatAdd(comptime T: type, lhs: T, rhs: T, expected: T) !void {
try expect((lhs +| rhs) == expected);
var x = lhs;
x +|= rhs;
try expect(x == expected);
}
test "saturating add" {
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
@@ -28,32 +36,23 @@ test "saturating add" {
try testSatAdd(u2, 3, 2, 3);
try testSatAdd(u3, 7, 1, 7);
}
fn testSatAdd(comptime T: type, lhs: T, rhs: T, expected: T) !void {
try expect((lhs +| rhs) == expected);
var x = lhs;
x +|= rhs;
try expect(x == expected);
}
};
try S.doTheTest();
try comptime S.doTheTest();
try comptime S.testSatAdd(comptime_int, 0, 0, 0);
try comptime S.testSatAdd(comptime_int, -1, 1, 0);
try comptime S.testSatAdd(comptime_int, 3, 2, 5);
try comptime S.testSatAdd(comptime_int, -3, -2, -5);
try comptime S.testSatAdd(comptime_int, 3, -2, 1);
try comptime S.testSatAdd(comptime_int, -3, 2, -1);
try comptime S.testSatAdd(comptime_int, 651075816498665588400716961808225370057, 468229432685078038144554201546849378455, 1119305249183743626545271163355074748512);
try comptime S.testSatAdd(comptime_int, 7, -593423721213448152027139550640105366508, -593423721213448152027139550640105366501);
try comptime testSatAdd(comptime_int, 0, 0, 0);
try comptime testSatAdd(comptime_int, -1, 1, 0);
try comptime testSatAdd(comptime_int, 3, 2, 5);
try comptime testSatAdd(comptime_int, -3, -2, -5);
try comptime testSatAdd(comptime_int, 3, -2, 1);
try comptime testSatAdd(comptime_int, -3, 2, -1);
try comptime testSatAdd(comptime_int, 651075816498665588400716961808225370057, 468229432685078038144554201546849378455, 1119305249183743626545271163355074748512);
try comptime testSatAdd(comptime_int, 7, -593423721213448152027139550640105366508, -593423721213448152027139550640105366501);
}
test "saturating add 128bit" {
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv) return error.SkipZigTest;
@@ -65,19 +64,20 @@ test "saturating add 128bit" {
try testSatAdd(i128, minInt(i128), maxInt(i128), -1);
try testSatAdd(u128, maxInt(u128), 1, maxInt(u128));
}
fn testSatAdd(comptime T: type, lhs: T, rhs: T, expected: T) !void {
try expect((lhs +| rhs) == expected);
var x = lhs;
x +|= rhs;
try expect(x == expected);
}
};
try S.doTheTest();
try comptime S.doTheTest();
}
fn testSatSub(comptime T: type, lhs: T, rhs: T, expected: T) !void {
try expect((lhs -| rhs) == expected);
var x = lhs;
x -|= rhs;
try expect(x == expected);
}
test "saturating subtraction" {
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
@@ -101,32 +101,23 @@ test "saturating subtraction" {
try testSatSub(u8, 10, 3, 7);
try testSatSub(u8, 0, 255, 0);
}
fn testSatSub(comptime T: type, lhs: T, rhs: T, expected: T) !void {
try expect((lhs -| rhs) == expected);
var x = lhs;
x -|= rhs;
try expect(x == expected);
}
};
try S.doTheTest();
try comptime S.doTheTest();
try comptime S.testSatSub(comptime_int, 0, 0, 0);
try comptime S.testSatSub(comptime_int, 1, 1, 0);
try comptime S.testSatSub(comptime_int, 3, 2, 1);
try comptime S.testSatSub(comptime_int, -3, -2, -1);
try comptime S.testSatSub(comptime_int, 3, -2, 5);
try comptime S.testSatSub(comptime_int, -3, 2, -5);
try comptime S.testSatSub(comptime_int, 651075816498665588400716961808225370057, 468229432685078038144554201546849378455, 182846383813587550256162760261375991602);
try comptime S.testSatSub(comptime_int, 7, -593423721213448152027139550640105366508, 593423721213448152027139550640105366515);
try comptime testSatSub(comptime_int, 0, 0, 0);
try comptime testSatSub(comptime_int, 1, 1, 0);
try comptime testSatSub(comptime_int, 3, 2, 1);
try comptime testSatSub(comptime_int, -3, -2, -1);
try comptime testSatSub(comptime_int, 3, -2, 5);
try comptime testSatSub(comptime_int, -3, 2, -5);
try comptime testSatSub(comptime_int, 651075816498665588400716961808225370057, 468229432685078038144554201546849378455, 182846383813587550256162760261375991602);
try comptime testSatSub(comptime_int, 7, -593423721213448152027139550640105366508, 593423721213448152027139550640105366515);
}
test "saturating subtraction 128bit" {
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv) return error.SkipZigTest;
@@ -138,14 +129,6 @@ test "saturating subtraction 128bit" {
try testSatSub(i128, minInt(i128), -maxInt(i128), -1);
try testSatSub(u128, 0, maxInt(u128), 0);
}
fn testSatSub(comptime T: type, lhs: T, rhs: T, expected: T) !void {
try expect((lhs -| rhs) == expected);
var x = lhs;
x -|= rhs;
try expect(x == expected);
}
};
try S.doTheTest();
@@ -257,7 +240,6 @@ test "saturating mul i64, i128" {
test "saturating multiplication" {
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv) return error.SkipZigTest;
@@ -294,6 +276,14 @@ test "saturating multiplication" {
try comptime testSatMul(comptime_int, 7, -593423721213448152027139550640105366508, -4153966048494137064189976854480737565556);
}
fn testSatShl(comptime Lhs: type, lhs: Lhs, comptime Rhs: type, rhs: Rhs, expected: Lhs) !void {
try expect((lhs <<| rhs) == expected);
var x = lhs;
x <<|= rhs;
try expect(x == expected);
}
test "saturating shift-left" {
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
@@ -320,23 +310,15 @@ test "saturating shift-left" {
try testSatShl(u8, 0, u4, 8, 0);
try testSatShl(u8, 3, u4, 8, maxInt(u8));
}
fn testSatShl(comptime Lhs: type, lhs: Lhs, comptime Rhs: type, rhs: Rhs, expected: Lhs) !void {
try expect((lhs <<| rhs) == expected);
var x = lhs;
x <<|= rhs;
try expect(x == expected);
}
};
try S.doTheTest();
try comptime S.doTheTest();
try comptime S.testSatShl(comptime_int, 0, comptime_int, 0, 0);
try comptime S.testSatShl(comptime_int, 1, comptime_int, 2, 4);
try comptime S.testSatShl(comptime_int, 13, comptime_int, 150, 18554220005177478453757717602843436772975706112);
try comptime S.testSatShl(comptime_int, -582769, comptime_int, 180, -893090893854873184096635538665358532628308979495815656505344);
try comptime testSatShl(comptime_int, 0, comptime_int, 0, 0);
try comptime testSatShl(comptime_int, 1, comptime_int, 2, 4);
try comptime testSatShl(comptime_int, 13, comptime_int, 150, 18554220005177478453757717602843436772975706112);
try comptime testSatShl(comptime_int, -582769, comptime_int, 180, -893090893854873184096635538665358532628308979495815656505344);
}
test "saturating shift-left large rhs" {
@@ -344,7 +326,6 @@ test "saturating shift-left large rhs" {
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_spirv) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
{
var lhs: u8 = undefined;
@@ -386,3 +367,59 @@ test "saturating shl uses the LHS type" {
try expect((1 <<| @as(u8, 200)) == 1606938044258990275541962092341162602522202993782792835301376);
}
test "sat add > 128 bits" {
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
try testSatAdd(u140, 0, 0, 0);
try testSatAdd(u140, maxInt(u140), 1, maxInt(u140));
try testSatAdd(u200, 1 << 150, 1 << 20, (1 << 150) + (1 << 20));
try testSatAdd(u200, maxInt(u200), maxInt(u200), maxInt(u200));
try testSatAdd(i140, minInt(i140), -1, minInt(i140));
try testSatAdd(i140, maxInt(i140), 1, maxInt(i140));
try testSatAdd(i200, -1 << 150, 1 << 149, -1 << 149);
try testSatAdd(i200, maxInt(i200), maxInt(i200), maxInt(i200));
}
test "sat sub > 128 bits" {
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
try testSatSub(u140, 0, 1, 0);
try testSatSub(u140, maxInt(u140), maxInt(u140), 0);
try testSatSub(u200, 1 << 150, 1 << 20, (1 << 150) - (1 << 20));
try testSatSub(u200, maxInt(u200), 0, maxInt(u200));
try testSatSub(i140, minInt(i140), 1, minInt(i140));
try testSatSub(i140, maxInt(i140), -1, maxInt(i140));
try testSatSub(i200, -1 << 150, 1 << 149, -3 << 149);
try testSatSub(i200, 0, minInt(i200), maxInt(i200));
}
test "sat mul > 128 bits" {
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
try testSatMul(u140, 0, maxInt(u140), 0);
try testSatMul(u140, 1 << 70, 1 << 69, 1 << 139);
try testSatMul(u200, maxInt(u200), 2, maxInt(u200));
try testSatMul(u200, maxInt(u200) - 1, 1, maxInt(u200) - 1);
try testSatMul(i140, -1, maxInt(i140), -maxInt(i140));
try testSatMul(i140, minInt(i140), -1, maxInt(i140));
try testSatMul(i200, 1 << 100, 1 << 99, maxInt(i200));
try testSatMul(i200, -1 << 150, 1 << 30, -1 << 180);
}
test "sat shl > 128 bits" {
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
try testSatShl(u140, 0, u8, 17, 0);
try testSatShl(u140, 1 << 100, u8, 20, 1 << 120);
try testSatShl(u200, maxInt(u200), u8, 1, maxInt(u200));
try testSatShl(u200, 1 << 199, u8, 1, maxInt(u200));
try testSatShl(i140, 0, u8, 17, 0);
try testSatShl(i140, 1 << 100, u8, 38, 1 << 138);
try testSatShl(i140, 1 << 100, u8, 39, maxInt(i140));
try testSatShl(i200, minInt(i200) + 1, u8, 1, minInt(i200));
}