From 3d89a5be5099cbdf4a9dd3c40d198b573246c0ae Mon Sep 17 00:00:00 2001 From: sayantn Date: Tue, 25 Nov 2025 22:39:43 +0530 Subject: [PATCH] Add autocasts for structs --- compiler/rustc_codegen_llvm/src/intrinsic.rs | 85 +++++++++++++++++--- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 3 + compiler/rustc_codegen_llvm/src/type_.rs | 10 +++ tests/codegen-llvm/inject-autocast.rs | 39 +++++++++ 4 files changed, 127 insertions(+), 10 deletions(-) create mode 100644 tests/codegen-llvm/inject-autocast.rs diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 4c66c4ef8bdd..6aedb6d97d0f 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -820,7 +820,7 @@ fn codegen_llvm_intrinsic_call( &mut self, instance: ty::Instance<'tcx>, args: &[OperandRef<'tcx, Self::Value>], - is_cleanup: bool, + _is_cleanup: bool, ) -> Self::Value { let tcx = self.tcx(); @@ -871,7 +871,7 @@ fn codegen_llvm_intrinsic_call( for arg in args { match arg.val { OperandValue::ZeroSized => {} - OperandValue::Immediate(_) => llargs.push(arg.immediate()), + OperandValue::Immediate(a) => llargs.push(a), OperandValue::Pair(a, b) => { llargs.push(a); llargs.push(b); @@ -897,24 +897,38 @@ fn codegen_llvm_intrinsic_call( } debug!("call intrinsic {:?} with args ({:?})", instance, llargs); - let args = self.check_call("call", fn_ty, fn_ptr, &llargs); + + for (dest_ty, arg) in iter::zip(self.func_params_types(fn_ty), &mut llargs) { + let src_ty = self.val_ty(arg); + assert!( + can_autocast(self, src_ty, dest_ty), + "Cannot match `{dest_ty:?}` (expected) with {src_ty:?} (found) in `{fn_ptr:?}" + ); + + *arg = autocast(self, arg, src_ty, dest_ty); + } + let llret = unsafe { llvm::LLVMBuildCallWithOperandBundles( self.llbuilder, fn_ty, fn_ptr, - args.as_ptr() as *const &llvm::Value, - args.len() as c_uint, + llargs.as_ptr(), + llargs.len() as c_uint, ptr::dangling(), 0, c"".as_ptr(), ) }; - if is_cleanup { - self.apply_attrs_to_cleanup_callsite(llret); - } - llret + let src_ty = self.val_ty(llret); + let dest_ty = llreturn_ty; + assert!( + can_autocast(self, dest_ty, src_ty), + "Cannot match `{src_ty:?}` (expected) with `{dest_ty:?}` (found) in `{fn_ptr:?}`" + ); + + autocast(self, llret, src_ty, dest_ty) } fn abort(&mut self) { @@ -985,6 +999,57 @@ fn llvm_arch_for(rust_arch: &Arch) -> Option<&'static str> { }) } +fn can_autocast<'ll>(cx: &CodegenCx<'ll, '_>, rust_ty: &'ll Type, llvm_ty: &'ll Type) -> bool { + if rust_ty == llvm_ty { + return true; + } + + // Some LLVM intrinsics return **non-packed** structs, but they can't be mimicked from Rust + // due to auto field-alignment in non-packed structs (packed structs are represented in LLVM + // as, well, packed structs, so they won't match with those either) + if cx.type_kind(llvm_ty) == TypeKind::Struct && cx.type_kind(rust_ty) == TypeKind::Struct { + let rust_element_tys = cx.struct_element_types(rust_ty); + let llvm_element_tys = cx.struct_element_types(llvm_ty); + + if rust_element_tys.len() != llvm_element_tys.len() { + return false; + } + + iter::zip(rust_element_tys, llvm_element_tys).all(|(rust_element_ty, llvm_element_ty)| { + can_autocast(cx, rust_element_ty, llvm_element_ty) + }) + } else { + false + } +} + +fn autocast<'ll>( + bx: &mut Builder<'_, 'll, '_>, + val: &'ll Value, + src_ty: &'ll Type, + dest_ty: &'ll Type, +) -> &'ll Value { + if src_ty == dest_ty { + return val; + } + match (bx.type_kind(src_ty), bx.type_kind(dest_ty)) { + // re-pack structs + (TypeKind::Struct, TypeKind::Struct) => { + let mut ret = bx.const_poison(dest_ty); + for (idx, (src_element_ty, dest_element_ty)) in + iter::zip(bx.struct_element_types(src_ty), bx.struct_element_types(dest_ty)) + .enumerate() + { + let elt = bx.extract_value(val, idx as u64); + let casted_elt = autocast(bx, elt, src_element_ty, dest_element_ty); + ret = bx.insert_value(ret, casted_elt, idx as u64); + } + ret + } + _ => unreachable!(), + } +} + fn intrinsic_fn<'ll, 'tcx>( bx: &Builder<'_, 'll, 'tcx>, name: &str, @@ -1030,7 +1095,7 @@ fn intrinsic_fn<'ll, 'tcx>( && rust_argument_tys.len() == llvm_argument_tys.len() && iter::once((rust_return_ty, llvm_return_ty)) .chain(iter::zip(rust_argument_tys, llvm_argument_tys)) - .all(|(rust_ty, llvm_ty)| rust_ty == llvm_ty); + .all(|(rust_ty, llvm_ty)| can_autocast(bx, rust_ty, llvm_ty)); if !is_correct_signature { tcx.dcx().emit_fatal(IntrinsicSignatureMismatch { diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 7edbaf5a5f33..deafa38b7be6 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1613,6 +1613,9 @@ pub(crate) fn LLVMStructSetBody<'a>( Packed: Bool, ); + pub(crate) fn LLVMCountStructElementTypes(StructTy: &Type) -> c_uint; + pub(crate) fn LLVMGetStructElementTypes<'a>(StructTy: &'a Type, Dest: *mut &'a Type); + pub(crate) safe fn LLVMMetadataAsValue<'a>(C: &'a Context, MD: &'a Metadata) -> &'a Value; pub(crate) safe fn LLVMSetUnnamedAddress(Global: &Value, UnnamedAddr: UnnamedAddr); diff --git a/compiler/rustc_codegen_llvm/src/type_.rs b/compiler/rustc_codegen_llvm/src/type_.rs index b8cee3510789..147056a5885a 100644 --- a/compiler/rustc_codegen_llvm/src/type_.rs +++ b/compiler/rustc_codegen_llvm/src/type_.rs @@ -94,6 +94,16 @@ pub(crate) fn func_params_types(&self, ty: &'ll Type) -> Vec<&'ll Type> { pub(crate) fn func_is_variadic(&self, ty: &'ll Type) -> bool { unsafe { llvm::LLVMIsFunctionVarArg(ty).is_true() } } + + pub(crate) fn struct_element_types(&self, ty: &'ll Type) -> Vec<&'ll Type> { + unsafe { + let n_args = llvm::LLVMCountStructElementTypes(ty) as usize; + let mut args = Vec::with_capacity(n_args); + llvm::LLVMGetStructElementTypes(ty, args.as_mut_ptr()); + args.set_len(n_args); + args + } + } } impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> { pub(crate) fn type_bool(&self) -> &'ll Type { diff --git a/tests/codegen-llvm/inject-autocast.rs b/tests/codegen-llvm/inject-autocast.rs new file mode 100644 index 000000000000..d79779285889 --- /dev/null +++ b/tests/codegen-llvm/inject-autocast.rs @@ -0,0 +1,39 @@ +//@ compile-flags: -C opt-level=0 -C target-feature=+kl +//@ only-x86_64 + +#![feature(link_llvm_intrinsics, abi_unadjusted, simd_ffi, portable_simd)] +#![crate_type = "lib"] + +use std::simd::i64x2; + +#[repr(C, packed)] +pub struct Bar(u32, i64x2, i64x2, i64x2, i64x2, i64x2, i64x2); +// CHECK: %Bar = type <{ i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> }> + +// CHECK-LABEL: @struct_autocast +#[no_mangle] +pub unsafe fn struct_autocast(key_metadata: u32, key: i64x2) -> Bar { + extern "unadjusted" { + #[link_name = "llvm.x86.encodekey128"] + fn foo(key_metadata: u32, key: i64x2) -> Bar; + } + + // CHECK: [[A:%[0-9]+]] = call { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32 {{.*}}, <2 x i64> {{.*}}) + // CHECK: [[B:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 0 + // CHECK: [[C:%[0-9]+]] = insertvalue %Bar poison, i32 [[B]], 0 + // CHECK: [[D:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 1 + // CHECK: [[E:%[0-9]+]] = insertvalue %Bar [[C]], <2 x i64> [[D]], 1 + // CHECK: [[F:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 2 + // CHECK: [[G:%[0-9]+]] = insertvalue %Bar [[E]], <2 x i64> [[F]], 2 + // CHECK: [[H:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 3 + // CHECK: [[I:%[0-9]+]] = insertvalue %Bar [[G]], <2 x i64> [[H]], 3 + // CHECK: [[J:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 4 + // CHECK: [[K:%[0-9]+]] = insertvalue %Bar [[I]], <2 x i64> [[J]], 4 + // CHECK: [[L:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 5 + // CHECK: [[M:%[0-9]+]] = insertvalue %Bar [[K]], <2 x i64> [[L]], 5 + // CHECK: [[N:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 6 + // CHECK: insertvalue %Bar [[M]], <2 x i64> [[N]], 6 + foo(key_metadata, key) +} + +// CHECK: declare { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32, <2 x i64>)