diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 70f145a7155b..d46672bdffb7 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -1022,12 +1022,20 @@ fn can_autocast<'ll>(cx: &CodegenCx<'ll, '_>, rust_ty: &'ll Type, llvm_ty: &'ll }, ) } - TypeKind::Vector if cx.element_type(llvm_ty) == cx.type_i1() => { + TypeKind::Vector => { + let llvm_element_ty = cx.element_type(llvm_ty); let element_count = cx.vector_length(llvm_ty) as u64; - let int_width = element_count.next_power_of_two().max(8); - rust_ty == cx.type_ix(int_width) + if llvm_element_ty == cx.type_bf16() { + rust_ty == cx.type_vector(cx.type_i16(), element_count) + } else if llvm_element_ty == cx.type_i1() { + let int_width = element_count.next_power_of_two().max(8); + rust_ty == cx.type_ix(int_width) + } else { + false + } } + TypeKind::BFloat => rust_ty == cx.type_i16(), _ => false, } } @@ -1097,7 +1105,7 @@ fn autocast<'ll>( ) } } - _ => unreachable!(), + _ => bx.bitcast(val, dest_ty), // for `bf16(xN)` <-> `u16(xN)` } } diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index deafa38b7be6..525d1dbe9d0d 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -921,6 +921,9 @@ pub(crate) fn LLVMGetInlineAsm<'ll>( pub(crate) fn LLVMDoubleTypeInContext(C: &Context) -> &Type; pub(crate) fn LLVMFP128TypeInContext(C: &Context) -> &Type; + // Operations on non-IEEE real types + pub(crate) fn LLVMBFloatTypeInContext(C: &Context) -> &Type; + // Operations on function types pub(crate) fn LLVMFunctionType<'a>( ReturnType: &'a Type, diff --git a/compiler/rustc_codegen_llvm/src/type_.rs b/compiler/rustc_codegen_llvm/src/type_.rs index 147056a5885a..796f3d9ef60b 100644 --- a/compiler/rustc_codegen_llvm/src/type_.rs +++ b/compiler/rustc_codegen_llvm/src/type_.rs @@ -183,6 +183,10 @@ pub(crate) fn type_struct(&self, els: &[&'ll Type], packed: bool) -> &'ll Type { ) } } + + pub(crate) fn type_bf16(&self) -> &'ll Type { + unsafe { llvm::LLVMBFloatTypeInContext(self.llcx()) } + } } impl<'ll, CX: Borrow>> BaseTypeCodegenMethods for GenericCx<'ll, CX> { diff --git a/tests/codegen-llvm/inject-autocast.rs b/tests/codegen-llvm/inject-autocast.rs index ae5bd0e42299..fec9d3f0b195 100644 --- a/tests/codegen-llvm/inject-autocast.rs +++ b/tests/codegen-llvm/inject-autocast.rs @@ -1,10 +1,10 @@ -//@ compile-flags: -C opt-level=0 -C target-feature=+kl,+avx512vp2intersect,+avx512vl +//@ compile-flags: -C opt-level=0 -C target-feature=+kl,+avx512vp2intersect,+avx512vl,+avxneconvert //@ only-x86_64 #![feature(link_llvm_intrinsics, abi_unadjusted, simd_ffi, portable_simd)] #![crate_type = "lib"] -use std::simd::i64x2; +use std::simd::{f32x4, i16x8, i64x2}; #[repr(C, packed)] pub struct Bar(u32, i64x2, i64x2, i64x2, i64x2, i64x2, i64x2); @@ -71,8 +71,23 @@ pub unsafe fn i1_vector_autocast(a: u8, b: u8) -> u8 { foo(a, b) } +// CHECK-LABEL: @bf16_vector_autocast +#[no_mangle] +pub unsafe fn bf16_vector_autocast(a: f32x4) -> i16x8 { + extern "unadjusted" { + #[link_name = "llvm.x86.vcvtneps2bf16128"] + fn foo(a: f32x4) -> i16x8; + } + + // CHECK: [[A:%[0-9]+]] = call <8 x bfloat> @llvm.x86.vcvtneps2bf16128(<4 x float> {{.*}}) + // CHECK: bitcast <8 x bfloat> [[A]] to <8 x i16> + foo(a) +} + // CHECK: declare { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32, <2 x i64>) // CHECK: declare { <2 x i1>, <2 x i1> } @llvm.x86.avx512.vp2intersect.q.128(<2 x i64>, <2 x i64>) // CHECK: declare <8 x i1> @llvm.x86.avx512.kadd.b(<8 x i1>, <8 x i1>) + +// CHECK: declare <8 x bfloat> @llvm.x86.vcvtneps2bf16128(<4 x float>)