Simplify raw_eq to Transmute+Eq for sizes with primitives

This commit is contained in:
Scott McMurray
2026-05-10 19:28:57 -07:00
parent 52ecad938e
commit c468ee3386
8 changed files with 142 additions and 23 deletions
@@ -1,11 +1,12 @@
//! Performs various peephole optimizations.
use rustc_abi::ExternAbi;
use rustc_abi::{ExternAbi, Integer};
use rustc_hir::{LangItem, find_attr};
use rustc_index::IndexVec;
use rustc_middle::bug;
use rustc_middle::mir::visit::MutVisitor;
use rustc_middle::mir::*;
use rustc_middle::ty::layout::ValidityRequirement;
use rustc_middle::ty::layout::{IntegerExt, ValidityRequirement};
use rustc_middle::ty::{self, GenericArgsRef, Ty, TyCtxt, layout};
use rustc_span::{Symbol, sym};
@@ -33,10 +34,10 @@ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
if !preserve_ub_checks {
SimplifyUbCheck { tcx }.visit_body(body);
}
let ctx = InstSimplifyContext {
let mut ctx = InstSimplifyContext {
tcx,
local_decls: &body.local_decls,
typing_env: body.typing_env(tcx),
local_decls: &mut body.local_decls,
};
for block in body.basic_blocks.as_mut() {
for statement in block.statements.iter_mut() {
@@ -55,6 +56,7 @@ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let terminator = block.terminator.as_mut().unwrap();
ctx.simplify_primitive_clone(terminator, &mut block.statements);
ctx.simplify_size_or_align_of_val(terminator, &mut block.statements);
ctx.simplify_raw_eq(terminator, &mut block.statements);
ctx.simplify_intrinsic_assert(terminator);
ctx.simplify_nounwind_call(terminator);
simplify_duplicate_switch_targets(terminator);
@@ -68,7 +70,7 @@ fn is_required(&self) -> bool {
struct InstSimplifyContext<'a, 'tcx> {
tcx: TyCtxt<'tcx>,
local_decls: &'a LocalDecls<'tcx>,
local_decls: &'a mut IndexVec<Local, LocalDecl<'tcx>>,
typing_env: ty::TypingEnv<'tcx>,
}
@@ -318,6 +320,63 @@ fn simplify_size_or_align_of_val(
}
}
/// Simplify `raw_eq` intrinsic calls to `Eq` when the type has the size of a primitive.
///
/// For example, replace `raw_eq::<[u8; 4]>(a, b)` with `Eq(Transmute(a), Transmute(b))`.
fn simplify_raw_eq(
&mut self,
terminator: &mut Terminator<'tcx>,
statements: &mut Vec<Statement<'tcx>>,
) {
let tcx = self.tcx;
let source_info = terminator.source_info;
let span = source_info.span;
if let TerminatorKind::Call {
func, args, destination, target: Some(destination_block), ..
} = &terminator.kind
&& args.len() == 2
&& let Some((fn_def_id, generics)) = func.const_fn_def()
&& tcx.is_intrinsic(fn_def_id, sym::raw_eq)
&& let generic_ty = generics.type_at(0)
&& let Ok(layout) = tcx.layout_of(self.typing_env.as_query_input(generic_ty))
&& let Ok(integer) = Integer::from_size(layout.size)
{
let ref_ty = Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, generic_ty);
let uint_ty = integer.to_ty(tcx, false);
let mut transmute_operand = |op: &Operand<'tcx>| -> Operand<'tcx> {
let ref_local = self.local_decls.push(LocalDecl::new(ref_ty, span));
statements.push(Statement::new(
source_info,
StatementKind::Assign(Box::new((
Place::from(ref_local),
Rvalue::Use(op.clone(), WithRetag::Yes),
))),
));
let place = Place::from(ref_local).project_deeper(&[ProjectionElem::Deref], tcx);
let int_local = self.local_decls.push(LocalDecl::new(uint_ty, span));
statements.push(Statement::new(
source_info,
StatementKind::Assign(Box::new((
Place::from(int_local),
Rvalue::Cast(CastKind::Transmute, Operand::Copy(place), uint_ty),
))),
));
Operand::Move(Place::from(int_local))
};
let lhs_op = transmute_operand(&args[0].node);
let rhs_op = transmute_operand(&args[1].node);
statements.push(Statement::new(
source_info,
StatementKind::Assign(Box::new((
*destination,
Rvalue::BinaryOp(BinOp::Eq, Box::new((lhs_op, rhs_op))),
))),
));
terminator.kind = TerminatorKind::Goto { target: *destination_block };
}
}
fn simplify_nounwind_call(&self, terminator: &mut Terminator<'tcx>) {
let TerminatorKind::Call { ref func, ref mut unwind, .. } = terminator.kind else {
return;
@@ -0,0 +1,25 @@
- // MIR for `inner_array` before InstSimplify-after-simplifycfg
+ // MIR for `inner_array` after InstSimplify-after-simplifycfg
fn inner_array(_1: &&[i32; 2], _2: &&[i32; 2]) -> bool {
let mut _0: bool;
+ let mut _3: &[i32; 2];
+ let mut _4: u64;
+ let mut _5: &[i32; 2];
+ let mut _6: u64;
bb0: {
- _0 = raw_eq::<[i32; 2]>(copy (*_1), copy (*_2)) -> [return: bb1, unwind unreachable];
+ _3 = copy (*_1);
+ _4 = copy (*_3) as u64 (Transmute);
+ _5 = copy (*_2);
+ _6 = copy (*_5) as u64 (Transmute);
+ _0 = Eq(move _4, move _6);
+ goto -> bb1;
}
bb1: {
return;
}
}
+27
View File
@@ -0,0 +1,27 @@
//@ test-mir-pass: InstSimplify-after-simplifycfg
#![crate_type = "lib"]
#![feature(core_intrinsics)]
#![feature(custom_mir)]
// Custom MIR so we can get an argument that's not just a local directly
use std::intrinsics::mir::*;
use std::intrinsics::raw_eq;
// EMIT_MIR raw_eq.inner_array.InstSimplify-after-simplifycfg.diff
#[custom_mir(dialect = "runtime")]
pub fn inner_array(a: &&[i32; 2], b: &&[i32; 2]) -> bool {
// CHECK-LABEL: fn inner_array(_1: &&[i32; 2], _2: &&[i32; 2]) -> bool
// CHECK: [[AREF:_.+]] = copy (*_1);
// CHECK: [[AINT:_.+]] = copy (*[[AREF]]) as u64 (Transmute);
// CHECK: [[BREF:_.+]] = copy (*_2);
// CHECK: [[BINT:_.+]] = copy (*[[BREF]]) as u64 (Transmute);
// CHECK: _0 = Eq(move [[AINT]], move [[BINT]]);
mir! {
{
Call(RET = raw_eq(*a, *b), ReturnTo(ret), UnwindUnreachable())
}
ret = {
Return()
}
}
}
@@ -4,6 +4,8 @@ fn eq_ipv4(_1: &[u8; 4], _2: &[u8; 4]) -> bool {
debug a => _1;
debug b => _2;
let mut _0: bool;
let mut _3: u32;
let mut _4: u32;
scope 1 (inlined std::cmp::impls::<impl PartialEq for &[u8; 4]>::eq) {
scope 2 (inlined array::equality::<impl PartialEq for [u8; 4]>::eq) {
scope 3 (inlined <u8 as array::equality::SpecArrayEq<u8, 4>>::spec_eq) {
@@ -12,10 +14,9 @@ fn eq_ipv4(_1: &[u8; 4], _2: &[u8; 4]) -> bool {
}
bb0: {
_0 = raw_eq::<[u8; 4]>(move _1, move _2) -> [return: bb1, unwind unreachable];
}
bb1: {
_3 = copy (*_1) as u32 (Transmute);
_4 = copy (*_2) as u32 (Transmute);
_0 = Eq(move _3, move _4);
return;
}
}
@@ -4,6 +4,8 @@ fn eq_ipv4(_1: &[u8; 4], _2: &[u8; 4]) -> bool {
debug a => _1;
debug b => _2;
let mut _0: bool;
let mut _3: u32;
let mut _4: u32;
scope 1 (inlined std::cmp::impls::<impl PartialEq for &[u8; 4]>::eq) {
scope 2 (inlined array::equality::<impl PartialEq for [u8; 4]>::eq) {
scope 3 (inlined <u8 as array::equality::SpecArrayEq<u8, 4>>::spec_eq) {
@@ -12,10 +14,9 @@ fn eq_ipv4(_1: &[u8; 4], _2: &[u8; 4]) -> bool {
}
bb0: {
_0 = raw_eq::<[u8; 4]>(move _1, move _2) -> [return: bb1, unwind unreachable];
}
bb1: {
_3 = copy (*_1) as u32 (Transmute);
_4 = copy (*_2) as u32 (Transmute);
_0 = Eq(move _3, move _4);
return;
}
}
@@ -4,6 +4,8 @@ fn eq_ipv6(_1: &[u16; 8], _2: &[u16; 8]) -> bool {
debug a => _1;
debug b => _2;
let mut _0: bool;
let mut _3: u128;
let mut _4: u128;
scope 1 (inlined std::cmp::impls::<impl PartialEq for &[u16; 8]>::eq) {
scope 2 (inlined array::equality::<impl PartialEq for [u16; 8]>::eq) {
scope 3 (inlined <u16 as array::equality::SpecArrayEq<u16, 8>>::spec_eq) {
@@ -12,10 +14,9 @@ fn eq_ipv6(_1: &[u16; 8], _2: &[u16; 8]) -> bool {
}
bb0: {
_0 = raw_eq::<[u16; 8]>(move _1, move _2) -> [return: bb1, unwind unreachable];
}
bb1: {
_3 = copy (*_1) as u128 (Transmute);
_4 = copy (*_2) as u128 (Transmute);
_0 = Eq(move _3, move _4);
return;
}
}
@@ -4,6 +4,8 @@ fn eq_ipv6(_1: &[u16; 8], _2: &[u16; 8]) -> bool {
debug a => _1;
debug b => _2;
let mut _0: bool;
let mut _3: u128;
let mut _4: u128;
scope 1 (inlined std::cmp::impls::<impl PartialEq for &[u16; 8]>::eq) {
scope 2 (inlined array::equality::<impl PartialEq for [u16; 8]>::eq) {
scope 3 (inlined <u16 as array::equality::SpecArrayEq<u16, 8>>::spec_eq) {
@@ -12,10 +14,9 @@ fn eq_ipv6(_1: &[u16; 8], _2: &[u16; 8]) -> bool {
}
bb0: {
_0 = raw_eq::<[u16; 8]>(move _1, move _2) -> [return: bb1, unwind unreachable];
}
bb1: {
_3 = copy (*_1) as u128 (Transmute);
_4 = copy (*_2) as u128 (Transmute);
_0 = Eq(move _3, move _4);
return;
}
}
+6 -2
View File
@@ -13,13 +13,17 @@
// EMIT_MIR array_eq.eq_ipv4.PreCodegen.after.mir
pub unsafe fn eq_ipv4<T: Copy>(a: &[u8; 4], b: &[u8; 4]) -> bool {
// CHECK-LABEL: fn eq_ipv4(_1: &[u8; 4], _2: &[u8; 4]) -> bool
// CHECK: _0 = raw_eq::<[u8; 4]>(move _1, move _2)
// CHECK: [[A:_.+]] = copy (*_1) as u32 (Transmute);
// CHECK: [[B:_.+]] = copy (*_2) as u32 (Transmute);
// CHECK: _0 = Eq(move [[A]], move [[B]]);
a == b
}
// EMIT_MIR array_eq.eq_ipv6.PreCodegen.after.mir
pub unsafe fn eq_ipv6<T: Copy>(a: &[u16; 8], b: &[u16; 8]) -> bool {
// CHECK-LABEL: fn eq_ipv6(_1: &[u16; 8], _2: &[u16; 8]) -> bool
// CHECK: _0 = raw_eq::<[u16; 8]>(move _1, move _2)
// CHECK: [[A:_.+]] = copy (*_1) as u128 (Transmute);
// CHECK: [[B:_.+]] = copy (*_2) as u128 (Transmute);
// CHECK: _0 = Eq(move [[A]], move [[B]]);
a == b
}