preserve SIMD element type information

and provide it to LLVM for better optimization
This commit is contained in:
Folkert de Vries
2026-04-08 21:03:50 +02:00
parent 14196dbfa3
commit 6f428df8df
16 changed files with 212 additions and 27 deletions
+3 -2
View File
@@ -75,10 +75,11 @@ pub fn homogeneous_aggregate<C>(&self, cx: &C) -> Result<HomogeneousAggregate, H
Ok(HomogeneousAggregate::Homogeneous(Reg { kind, size: self.size }))
}
BackendRepr::SimdVector { .. } => {
BackendRepr::SimdVector { element, count: _ } => {
assert!(!self.is_zst());
Ok(HomogeneousAggregate::Homogeneous(Reg {
kind: RegKind::Vector,
kind: RegKind::Vector { hint_vector_elem: element.primitive() },
size: self.size,
}))
}
+14 -3
View File
@@ -1,14 +1,19 @@
#[cfg(feature = "nightly")]
use rustc_macros::HashStable_Generic;
use crate::{Align, HasDataLayout, Size};
use crate::{Align, HasDataLayout, Integer, Primitive, Size};
#[cfg_attr(feature = "nightly", derive(HashStable_Generic))]
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub enum RegKind {
Integer,
Float,
Vector,
Vector {
/// The `hint_vector_elem` is strictly for optimization purposes. E.g. it can be used by
/// a codegen backend to prevent extra bitcasts that obscure a pattern. Alternatively,
/// it can be safely ignored by always picking i8.
hint_vector_elem: Primitive,
},
}
#[cfg_attr(feature = "nightly", derive(HashStable_Generic))]
@@ -36,6 +41,12 @@ impl Reg {
reg_ctor!(f32, Float, 32);
reg_ctor!(f64, Float, 64);
reg_ctor!(f128, Float, 128);
/// A vector of the given size with an unknown (and irrelevant) element type.
pub fn opaque_vector(size: Size) -> Reg {
// Default to an i8 vector of the given size.
Reg { kind: RegKind::Vector { hint_vector_elem: Primitive::Int(Integer::I8, true) }, size }
}
}
impl Reg {
@@ -58,7 +69,7 @@ pub fn align<C: HasDataLayout>(&self, cx: &C) -> Align {
128 => dl.f128_align,
_ => panic!("unsupported float: {self:?}"),
},
RegKind::Vector => dl.llvmlike_vector_align(self.size),
RegKind::Vector { .. } => dl.llvmlike_vector_align(self.size),
}
}
}
@@ -26,7 +26,9 @@ fn reg_to_abi_param(reg: Reg) -> AbiParam {
(RegKind::Float, 4) => types::F32,
(RegKind::Float, 8) => types::F64,
(RegKind::Float, 16) => types::F128,
(RegKind::Vector, size) => types::I8.by(u32::try_from(size).unwrap()).unwrap(),
(RegKind::Vector { hint_vector_elem: _ }, size) => {
types::I8.by(u32::try_from(size).unwrap()).unwrap()
}
_ => unreachable!("{:?}", reg),
};
AbiParam::new(clif_ty)
+3 -1
View File
@@ -90,7 +90,9 @@ fn gcc_type<'gcc>(&self, cx: &CodegenCx<'gcc, '_>) -> Type<'gcc> {
64 => cx.type_f64(),
_ => bug!("unsupported float: {:?}", self),
},
RegKind::Vector => cx.type_vector(cx.type_i8(), self.size.bytes()),
RegKind::Vector { hint_vector_elem: _ } => {
cx.type_vector(cx.type_i8(), self.size.bytes())
}
}
}
}
+27 -3
View File
@@ -2,8 +2,8 @@
use libc::c_uint;
use rustc_abi::{
ArmCall, BackendRepr, CanonAbi, HasDataLayout, InterruptKind, Primitive, Reg, RegKind, Size,
X86Call,
ArmCall, BackendRepr, CanonAbi, Float, HasDataLayout, Integer, InterruptKind, Primitive, Reg,
RegKind, Size, X86Call,
};
use rustc_codegen_ssa::MemFlags;
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
@@ -137,7 +137,31 @@ fn llvm_type<'ll>(&self, cx: &CodegenCx<'ll, '_>) -> &'ll Type {
128 => cx.type_f128(),
_ => bug!("unsupported float: {:?}", self),
},
RegKind::Vector => cx.type_vector(cx.type_i8(), self.size.bytes()),
RegKind::Vector { hint_vector_elem } => {
// NOTE: it is valid to ignore the element type hint (and always pick i8).
// But providing a more accurate type means fewer casts in LLVM IR,
// which helps with optimization.
let ty = match hint_vector_elem {
Primitive::Int(integer, _) => match integer {
Integer::I8 => cx.type_ix(8),
Integer::I16 => cx.type_ix(16),
Integer::I32 => cx.type_ix(32),
Integer::I64 => cx.type_ix(64),
Integer::I128 => cx.type_ix(128),
},
Primitive::Float(float) => match float {
Float::F16 => cx.type_f16(),
Float::F32 => cx.type_f32(),
Float::F64 => cx.type_f64(),
Float::F128 => cx.type_f128(),
},
Primitive::Pointer(_) => cx.type_ptr(),
};
assert!(self.size.bytes().is_multiple_of(hint_vector_elem.size(cx).bytes()));
let len = self.size.bytes() / hint_vector_elem.size(cx).bytes();
cx.type_vector(ty, len)
}
}
}
}
@@ -421,7 +421,7 @@ fn wasm_type<'tcx>(signature: &mut String, arg_abi: &ArgAbi<'_, Ty<'tcx>>, ptr_t
..=8 => "f64",
_ => ptr_type,
},
RegKind::Vector => "v128",
RegKind::Vector { .. } => "v128",
};
signature.push_str(wrapped_wasm_type);
@@ -25,8 +25,11 @@ fn passes_vectors_by_value(mode: &PassMode, repr: &BackendRepr) -> UsesVectorReg
match mode {
PassMode::Ignore | PassMode::Indirect { .. } => UsesVectorRegisters::No,
PassMode::Cast { pad_i32: _, cast }
if cast.prefix.iter().any(|r| r.is_some_and(|x| x.kind == RegKind::Vector))
|| cast.rest.unit.kind == RegKind::Vector =>
if cast
.prefix
.iter()
.any(|r| r.is_some_and(|x| matches!(x.kind, RegKind::Vector { .. })))
|| matches!(cast.rest.unit.kind, RegKind::Vector { .. }) =>
{
UsesVectorRegisters::FixedVector
}
@@ -35,7 +35,7 @@ fn is_homogeneous_aggregate<'a, Ty, C>(cx: &C, arg: &mut ArgAbi<'a, Ty>) -> Opti
// The softfloat ABI treats floats like integers, so they
// do not get homogeneous aggregate treatment.
RegKind::Float => cx.target_spec().rustc_abi != Some(RustcAbi::Softfloat),
RegKind::Vector => size.bits() == 64 || size.bits() == 128,
RegKind::Vector { .. } => size.bits() == 64 || size.bits() == 128,
};
valid_unit.then_some(Uniform::consecutive(unit, size))
+1 -1
View File
@@ -19,7 +19,7 @@ fn is_homogeneous_aggregate<'a, Ty, C>(cx: &C, arg: &mut ArgAbi<'a, Ty>) -> Opti
let valid_unit = match unit.kind {
RegKind::Integer => false,
RegKind::Float => true,
RegKind::Vector => size.bits() == 64 || size.bits() == 128,
RegKind::Vector { .. } => size.bits() == 64 || size.bits() == 128,
};
valid_unit.then_some(Uniform::consecutive(unit, size))
@@ -36,7 +36,7 @@ fn is_homogeneous_aggregate<'a, Ty, C>(
let valid_unit = match unit.kind {
RegKind::Integer => false,
RegKind::Float => true,
RegKind::Vector => arg.layout.size.bits() == 128,
RegKind::Vector { .. } => arg.layout.size.bits() == 128,
};
valid_unit.then_some(Uniform::consecutive(unit, arg.layout.size))
+2 -2
View File
@@ -3,7 +3,7 @@
use rustc_abi::{BackendRepr, HasDataLayout, TyAbiInterface};
use crate::callconv::{ArgAbi, FnAbi, Reg, RegKind};
use crate::callconv::{ArgAbi, FnAbi, Reg};
use crate::spec::{Env, HasTargetSpec, Os};
fn classify_ret<Ty>(ret: &mut ArgAbi<'_, Ty>) {
@@ -51,7 +51,7 @@ fn classify_arg<'a, Ty, C>(cx: &C, arg: &mut ArgAbi<'a, Ty>)
if arg.layout.is_single_vector_element(cx, size) {
// pass non-transparent wrappers around a vector as `PassMode::Cast`
arg.cast_to(Reg { kind: RegKind::Vector, size });
arg.cast_to(Reg::opaque_vector(size));
return;
}
}
+4 -5
View File
@@ -1,9 +1,8 @@
use rustc_abi::{
AddressSpace, Align, BackendRepr, HasDataLayout, Primitive, Reg, RegKind, TyAbiInterface,
TyAndLayout,
AddressSpace, Align, BackendRepr, HasDataLayout, Primitive, Reg, RegKind, TyAndLayout,
};
use crate::callconv::{ArgAttribute, FnAbi, PassMode};
use crate::callconv::{ArgAttribute, FnAbi, PassMode, TyAbiInterface};
use crate::spec::{HasTargetSpec, RustcAbi};
#[derive(PartialEq)]
@@ -175,7 +174,7 @@ pub(crate) fn fill_inregs<'a, Ty, C>(
// At this point we know this must be a primitive of sorts.
let unit = arg.layout.homogeneous_aggregate(cx).unwrap().unit().unwrap();
assert_eq!(unit.size, arg.layout.size);
if matches!(unit.kind, RegKind::Float | RegKind::Vector) {
if matches!(unit.kind, RegKind::Float | RegKind::Vector { .. }) {
continue;
}
@@ -226,7 +225,7 @@ pub(crate) fn compute_rust_abi_info<'a, Ty, C>(cx: &C, fn_abi: &mut FnAbi<'a, Ty
// This is a single scalar that fits into an SSE register, and the target uses the
// SSE ABI. We prefer this over integer registers as float scalars need to be in SSE
// registers for float operations, so that's the best place to pass them around.
fn_abi.ret.cast_to(Reg { kind: RegKind::Vector, size: fn_abi.ret.layout.size });
fn_abi.ret.cast_to(Reg::opaque_vector(fn_abi.ret.layout.size));
} else if fn_abi.ret.layout.size <= Primitive::Pointer(AddressSpace::ZERO).size(cx) {
// Same size or smaller than pointer, return in an integer register.
fn_abi.ret.cast_to(Reg { kind: RegKind::Integer, size: fn_abi.ret.layout.size });
+1 -1
View File
@@ -151,7 +151,7 @@ fn reg_component(cls: &[Option<Class>], i: &mut usize, size: Size) -> Option<Reg
_ => Reg::f64(),
}
} else {
Reg { kind: RegKind::Vector, size: Size::from_bytes(8) * (vec_len as u64) }
Reg::opaque_vector(Size::from_bytes(8) * (vec_len as u64))
})
}
Some(c) => unreachable!("reg_component: unhandled class {:?}", c),
@@ -1,4 +1,4 @@
use rustc_abi::{BackendRepr, Float, Integer, Primitive, RegKind, Size, TyAbiInterface};
use rustc_abi::{BackendRepr, Float, Integer, Primitive, Size, TyAbiInterface};
use crate::callconv::{ArgAbi, FnAbi, Reg};
use crate::spec::{HasTargetSpec, RustcAbi};
@@ -33,8 +33,7 @@ pub(crate) fn compute_abi_info<'a, Ty, C: HasTargetSpec>(cx: &C, fn_abi: &mut Fn
} else {
// `i128` is returned in xmm0 by Clang and GCC
// FIXME(#134288): This may change for the `-msvc` targets in the future.
let reg = Reg { kind: RegKind::Vector, size: Size::from_bits(128) };
a.cast_to(reg);
a.cast_to(Reg::opaque_vector(Size::from_bits(128)));
}
} else if a.layout.size.bytes() > 8
&& !matches!(scalar.primitive(), Primitive::Float(Float::F128))
+46
View File
@@ -0,0 +1,46 @@
//@ assembly-output: emit-asm
//@ compile-flags: -Copt-level=3
//@ only-aarch64-unknown-linux-gnu
#![feature(repr_simd, portable_simd, core_intrinsics, f16, f128)]
#![crate_type = "lib"]
#![allow(non_camel_case_types)]
// Test that `vld2_s16` can be implemented in a portable way (i.e. without using LLVM neon intrinsics).
// This relies on rust preserving the SIMD vector element type and using it to construct the
// LLVM type. Without this information, additional casts are needed that defeat the LLVM pattern
// matcher, see https://github.com/llvm/llvm-project/issues/181514.
use std::mem::transmute;
use std::simd::Simd;
// Baseline: calls the NEON `vld2_s16` intrinsic directly. This is expected to
// lower to a single de-interleaving `ld2` load; the portable version below must
// produce the same assembly.
#[unsafe(no_mangle)]
#[target_feature(enable = "neon")]
unsafe extern "C" fn vld2_s16_old(ptr: *const i16) -> std::arch::aarch64::int16x4x2_t {
// CHECK-LABEL: vld2_s16_old
// CHECK: .cfi_startproc
// CHECK-NEXT: ld2 { v0.4h, v1.4h }, [x0]
// CHECK-NEXT: ret
std::arch::aarch64::vld2_s16(ptr)
}
// Portable re-implementation: one unaligned 8-lane load followed by two
// de-interleaving shuffles. With the SIMD element type preserved through the
// ABI, LLVM can pattern-match this back to a single `ld2` instruction — no
// extra bitcasts obscure the pattern (see the LLVM issue linked above).
#[unsafe(no_mangle)]
#[target_feature(enable = "neon")]
unsafe extern "C" fn vld2_s16_new(a: *const i16) -> std::arch::aarch64::int16x4x2_t {
// CHECK-LABEL: vld2_s16_new
// CHECK: .cfi_startproc
// CHECK-NEXT: ld2 { v0.4h, v1.4h }, [x0]
// CHECK-NEXT: ret
type V = Simd<i16, 4>;
type W = Simd<i16, 8>;
let w: W = std::ptr::read_unaligned(a as *const W);
#[repr(simd)]
pub(crate) struct SimdShuffleIdx<const LEN: usize>([u32; LEN]);
// Even lanes form the first result vector, odd lanes the second (de-interleave).
let v0: V = std::intrinsics::simd::simd_shuffle(w, w, const { SimdShuffleIdx([0, 2, 4, 6]) });
let v1: V = std::intrinsics::simd::simd_shuffle(w, w, const { SimdShuffleIdx([1, 3, 5, 7]) });
transmute((v0, v1))
}
@@ -0,0 +1,98 @@
// ignore-tidy-linelength
//@ compile-flags: -Copt-level=3 -Zmerge-functions=disabled --target=aarch64-unknown-linux-gnu
//@ needs-llvm-components: aarch64
//@ add-minicore
#![feature(no_core, repr_simd, f16, f128)]
#![crate_type = "lib"]
#![no_std]
#![no_core]
#![allow(non_camel_case_types)]
// Test that the SIMD vector element type is preserved. This is not required for correctness, but
// useful for optimization. It prevents additional bitcasts that make LLVM patterns fail.
extern crate minicore;
use minicore::*;
/// Minimal generic SIMD vector type, standing in for `std::simd::Simd` in this
/// `no_core` test.
#[repr(simd)]
pub struct Simd<T, const N: usize>([T; N]);
/// Homogeneous aggregate of two vectors (mirrors NEON `..x2_t` tuple types).
#[repr(C)]
struct Pair<T>(T, T);
/// Homogeneous aggregate of three vectors (mirrors NEON `..x3_t` tuple types).
#[repr(C)]
struct Triple<T>(T, T, T);
/// Homogeneous aggregate of four vectors (mirrors NEON `..x4_t` tuple types).
#[repr(C)]
struct Quad<T>(T, T, T, T);
// Each function below is an identity function over a homogeneous aggregate of
// SIMD vectors. The CHECK lines assert that when the aggregate is passed in
// vector registers, the LLVM signature keeps the element type (e.g.
// `[2 x <4 x i16>]` rather than an opaque `[2 x <8 x i8>]`), which is what the
// element-type hint in `RegKind::Vector` is for.
#[rustfmt::skip]
mod tests {
use super::*;
// CHECK: define [2 x <8 x i8>] @pair_int8x8_t([2 x <8 x i8>] {{.*}} %0)
#[unsafe(no_mangle)] extern "C" fn pair_int8x8_t(x: Pair<Simd<i8, 8>>) -> Pair<Simd<i8, 8>> { x }
// CHECK: define [2 x <4 x i16>] @pair_int16x4_t([2 x <4 x i16>] {{.*}} %0)
#[unsafe(no_mangle)] extern "C" fn pair_int16x4_t(x: Pair<Simd<i16, 4>>) -> Pair<Simd<i16, 4>> { x }
// CHECK: define [2 x <2 x i32>] @pair_int32x2_t([2 x <2 x i32>] {{.*}} %0)
#[unsafe(no_mangle)] extern "C" fn pair_int32x2_t(x: Pair<Simd<i32, 2>>) -> Pair<Simd<i32, 2>> { x }
// CHECK: define [2 x <1 x i64>] @pair_int64x1_t([2 x <1 x i64>] {{.*}} %0)
#[unsafe(no_mangle)] extern "C" fn pair_int64x1_t(x: Pair<Simd<i64, 1>>) -> Pair<Simd<i64, 1>> { x }
// CHECK: define [2 x <4 x half>] @pair_float16x4_t([2 x <4 x half>] {{.*}} %0)
#[unsafe(no_mangle)] extern "C" fn pair_float16x4_t(x: Pair<Simd<f16, 4>>) -> Pair<Simd<f16, 4>> { x }
// CHECK: define [2 x <2 x float>] @pair_float32x2_t([2 x <2 x float>] {{.*}} %0)
#[unsafe(no_mangle)] extern "C" fn pair_float32x2_t(x: Pair<Simd<f32, 2>>) -> Pair<Simd<f32, 2>> { x }
// CHECK: define [2 x <1 x double>] @pair_float64x1_t([2 x <1 x double>] {{.*}} %0)
#[unsafe(no_mangle)] extern "C" fn pair_float64x1_t(x: Pair<Simd<f64, 1>>) -> Pair<Simd<f64, 1>> { x }
// CHECK: define [2 x <1 x ptr>] @pair_ptrx1_t([2 x <1 x ptr>] {{.*}} %0)
#[unsafe(no_mangle)] extern "C" fn pair_ptrx1_t(x: Pair<Simd<*const (), 1>>) -> Pair<Simd<*const (), 1>> { x }
// When it fits in a 128-bit register, it's passed directly.
// CHECK: define [4 x <4 x i8>] @quad_int8x4_t([4 x <4 x i8>] {{.*}} %0)
#[unsafe(no_mangle)] extern "C" fn quad_int8x4_t(x: Quad<Simd<i8, 4>>) -> Quad<Simd<i8, 4>> { x }
// CHECK: define [4 x <2 x i16>] @quad_int16x2_t([4 x <2 x i16>] {{.*}} %0)
#[unsafe(no_mangle)] extern "C" fn quad_int16x2_t(x: Quad<Simd<i16, 2>>) -> Quad<Simd<i16, 2>> { x }
// CHECK: define [4 x <1 x i32>] @quad_int32x1_t([4 x <1 x i32>] {{.*}} %0)
#[unsafe(no_mangle)] extern "C" fn quad_int32x1_t(x: Quad<Simd<i32, 1>>) -> Quad<Simd<i32, 1>> { x }
// CHECK: define [4 x <2 x half>] @quad_float16x2_t([4 x <2 x half>] {{.*}} %0)
#[unsafe(no_mangle)] extern "C" fn quad_float16x2_t(x: Quad<Simd<f16, 2>>) -> Quad<Simd<f16, 2>> { x }
// CHECK: define [4 x <1 x float>] @quad_float32x1_t([4 x <1 x float>] {{.*}} %0)
#[unsafe(no_mangle)] extern "C" fn quad_float32x1_t(x: Quad<Simd<f32, 1>>) -> Quad<Simd<f32, 1>> { x }
// When it doesn't quite fit, padding is added which does erase the type.
// CHECK: define [2 x i64] @triple_int8x4_t
#[unsafe(no_mangle)] extern "C" fn triple_int8x4_t(x: Triple<Simd<i8, 4>>) -> Triple<Simd<i8, 4>> { x }
// Other configurations are not passed by-value but indirectly.
// CHECK: define void @pair_int128x1_t
#[unsafe(no_mangle)] extern "C" fn pair_int128x1_t(x: Pair<Simd<i128, 1>>) -> Pair<Simd<i128, 1>> { x }
// CHECK: define void @pair_float128x1_t
#[unsafe(no_mangle)] extern "C" fn pair_float128x1_t(x: Pair<Simd<f128, 1>>) -> Pair<Simd<f128, 1>> { x }
// CHECK: define void @pair_int8x16_t
#[unsafe(no_mangle)] extern "C" fn pair_int8x16_t(x: Pair<Simd<i8, 16>>) -> Pair<Simd<i8, 16>> { x }
// CHECK: define void @pair_int16x8_t
#[unsafe(no_mangle)] extern "C" fn pair_int16x8_t(x: Pair<Simd<i16, 8>>) -> Pair<Simd<i16, 8>> { x }
// CHECK: define void @triple_int16x8_t
#[unsafe(no_mangle)] extern "C" fn triple_int16x8_t(x: Triple<Simd<i16, 8>>) -> Triple<Simd<i16, 8>> { x }
// CHECK: define void @quad_int16x8_t
#[unsafe(no_mangle)] extern "C" fn quad_int16x8_t(x: Quad<Simd<i16, 8>>) -> Quad<Simd<i16, 8>> { x }
}