Add autocasts for structs

This commit is contained in:
sayantn
2025-11-25 22:39:43 +05:30
parent a5372be2a1
commit 3d89a5be50
4 changed files with 127 additions and 10 deletions
+75 -10
View File
@@ -820,7 +820,7 @@ fn codegen_llvm_intrinsic_call(
&mut self,
instance: ty::Instance<'tcx>,
args: &[OperandRef<'tcx, Self::Value>],
is_cleanup: bool,
_is_cleanup: bool,
) -> Self::Value {
let tcx = self.tcx();
@@ -871,7 +871,7 @@ fn codegen_llvm_intrinsic_call(
for arg in args {
match arg.val {
OperandValue::ZeroSized => {}
OperandValue::Immediate(_) => llargs.push(arg.immediate()),
OperandValue::Immediate(a) => llargs.push(a),
OperandValue::Pair(a, b) => {
llargs.push(a);
llargs.push(b);
@@ -897,24 +897,38 @@ fn codegen_llvm_intrinsic_call(
}
debug!("call intrinsic {:?} with args ({:?})", instance, llargs);
let args = self.check_call("call", fn_ty, fn_ptr, &llargs);
for (dest_ty, arg) in iter::zip(self.func_params_types(fn_ty), &mut llargs) {
let src_ty = self.val_ty(arg);
assert!(
can_autocast(self, src_ty, dest_ty),
"Cannot match `{dest_ty:?}` (expected) with {src_ty:?} (found) in `{fn_ptr:?}"
);
*arg = autocast(self, arg, src_ty, dest_ty);
}
let llret = unsafe {
llvm::LLVMBuildCallWithOperandBundles(
self.llbuilder,
fn_ty,
fn_ptr,
args.as_ptr() as *const &llvm::Value,
args.len() as c_uint,
llargs.as_ptr(),
llargs.len() as c_uint,
ptr::dangling(),
0,
c"".as_ptr(),
)
};
if is_cleanup {
self.apply_attrs_to_cleanup_callsite(llret);
}
llret
let src_ty = self.val_ty(llret);
let dest_ty = llreturn_ty;
assert!(
can_autocast(self, dest_ty, src_ty),
"Cannot match `{src_ty:?}` (expected) with `{dest_ty:?}` (found) in `{fn_ptr:?}`"
);
autocast(self, llret, src_ty, dest_ty)
}
fn abort(&mut self) {
@@ -985,6 +999,57 @@ fn llvm_arch_for(rust_arch: &Arch) -> Option<&'static str> {
})
}
fn can_autocast<'ll>(cx: &CodegenCx<'ll, '_>, rust_ty: &'ll Type, llvm_ty: &'ll Type) -> bool {
if rust_ty == llvm_ty {
return true;
}
// Some LLVM intrinsics return **non-packed** structs, but they can't be mimicked from Rust
// due to auto field-alignment in non-packed structs (packed structs are represented in LLVM
// as, well, packed structs, so they won't match with those either)
if cx.type_kind(llvm_ty) == TypeKind::Struct && cx.type_kind(rust_ty) == TypeKind::Struct {
let rust_element_tys = cx.struct_element_types(rust_ty);
let llvm_element_tys = cx.struct_element_types(llvm_ty);
if rust_element_tys.len() != llvm_element_tys.len() {
return false;
}
iter::zip(rust_element_tys, llvm_element_tys).all(|(rust_element_ty, llvm_element_ty)| {
can_autocast(cx, rust_element_ty, llvm_element_ty)
})
} else {
false
}
}
fn autocast<'ll>(
bx: &mut Builder<'_, 'll, '_>,
val: &'ll Value,
src_ty: &'ll Type,
dest_ty: &'ll Type,
) -> &'ll Value {
if src_ty == dest_ty {
return val;
}
match (bx.type_kind(src_ty), bx.type_kind(dest_ty)) {
// re-pack structs
(TypeKind::Struct, TypeKind::Struct) => {
let mut ret = bx.const_poison(dest_ty);
for (idx, (src_element_ty, dest_element_ty)) in
iter::zip(bx.struct_element_types(src_ty), bx.struct_element_types(dest_ty))
.enumerate()
{
let elt = bx.extract_value(val, idx as u64);
let casted_elt = autocast(bx, elt, src_element_ty, dest_element_ty);
ret = bx.insert_value(ret, casted_elt, idx as u64);
}
ret
}
_ => unreachable!(),
}
}
fn intrinsic_fn<'ll, 'tcx>(
bx: &Builder<'_, 'll, 'tcx>,
name: &str,
@@ -1030,7 +1095,7 @@ fn intrinsic_fn<'ll, 'tcx>(
&& rust_argument_tys.len() == llvm_argument_tys.len()
&& iter::once((rust_return_ty, llvm_return_ty))
.chain(iter::zip(rust_argument_tys, llvm_argument_tys))
.all(|(rust_ty, llvm_ty)| rust_ty == llvm_ty);
.all(|(rust_ty, llvm_ty)| can_autocast(bx, rust_ty, llvm_ty));
if !is_correct_signature {
tcx.dcx().emit_fatal(IntrinsicSignatureMismatch {
@@ -1613,6 +1613,9 @@ pub(crate) fn LLVMStructSetBody<'a>(
Packed: Bool,
);
pub(crate) fn LLVMCountStructElementTypes(StructTy: &Type) -> c_uint;
pub(crate) fn LLVMGetStructElementTypes<'a>(StructTy: &'a Type, Dest: *mut &'a Type);
pub(crate) safe fn LLVMMetadataAsValue<'a>(C: &'a Context, MD: &'a Metadata) -> &'a Value;
pub(crate) safe fn LLVMSetUnnamedAddress(Global: &Value, UnnamedAddr: UnnamedAddr);
+10
View File
@@ -94,6 +94,16 @@ pub(crate) fn func_params_types(&self, ty: &'ll Type) -> Vec<&'ll Type> {
pub(crate) fn func_is_variadic(&self, ty: &'ll Type) -> bool {
unsafe { llvm::LLVMIsFunctionVarArg(ty).is_true() }
}
pub(crate) fn struct_element_types(&self, ty: &'ll Type) -> Vec<&'ll Type> {
unsafe {
let n_args = llvm::LLVMCountStructElementTypes(ty) as usize;
let mut args = Vec::with_capacity(n_args);
llvm::LLVMGetStructElementTypes(ty, args.as_mut_ptr());
args.set_len(n_args);
args
}
}
}
impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
pub(crate) fn type_bool(&self) -> &'ll Type {
+39
View File
@@ -0,0 +1,39 @@
//@ compile-flags: -C opt-level=0 -C target-feature=+kl
//@ only-x86_64
#![feature(link_llvm_intrinsics, abi_unadjusted, simd_ffi, portable_simd)]
#![crate_type = "lib"]
use std::simd::i64x2;
#[repr(C, packed)]
pub struct Bar(u32, i64x2, i64x2, i64x2, i64x2, i64x2, i64x2);
// CHECK: %Bar = type <{ i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> }>
// CHECK-LABEL: @struct_autocast
#[no_mangle]
pub unsafe fn struct_autocast(key_metadata: u32, key: i64x2) -> Bar {
extern "unadjusted" {
#[link_name = "llvm.x86.encodekey128"]
fn foo(key_metadata: u32, key: i64x2) -> Bar;
}
// CHECK: [[A:%[0-9]+]] = call { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32 {{.*}}, <2 x i64> {{.*}})
// CHECK: [[B:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 0
// CHECK: [[C:%[0-9]+]] = insertvalue %Bar poison, i32 [[B]], 0
// CHECK: [[D:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 1
// CHECK: [[E:%[0-9]+]] = insertvalue %Bar [[C]], <2 x i64> [[D]], 1
// CHECK: [[F:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 2
// CHECK: [[G:%[0-9]+]] = insertvalue %Bar [[E]], <2 x i64> [[F]], 2
// CHECK: [[H:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 3
// CHECK: [[I:%[0-9]+]] = insertvalue %Bar [[G]], <2 x i64> [[H]], 3
// CHECK: [[J:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 4
// CHECK: [[K:%[0-9]+]] = insertvalue %Bar [[I]], <2 x i64> [[J]], 4
// CHECK: [[L:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 5
// CHECK: [[M:%[0-9]+]] = insertvalue %Bar [[K]], <2 x i64> [[L]], 5
// CHECK: [[N:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 6
// CHECK: insertvalue %Bar [[M]], <2 x i64> [[N]], 6
foo(key_metadata, key)
}
// CHECK: declare { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32, <2 x i64>)