mirror of
https://github.com/rust-lang/rust.git
synced 2026-04-27 18:57:42 +03:00
Add autocast for bf16 and bf16xN
This commit is contained in:
@@ -1022,12 +1022,20 @@ fn can_autocast<'ll>(cx: &CodegenCx<'ll, '_>, rust_ty: &'ll Type, llvm_ty: &'ll
|
||||
},
|
||||
)
|
||||
}
|
||||
TypeKind::Vector if cx.element_type(llvm_ty) == cx.type_i1() => {
|
||||
TypeKind::Vector => {
|
||||
let llvm_element_ty = cx.element_type(llvm_ty);
|
||||
let element_count = cx.vector_length(llvm_ty) as u64;
|
||||
let int_width = element_count.next_power_of_two().max(8);
|
||||
|
||||
rust_ty == cx.type_ix(int_width)
|
||||
if llvm_element_ty == cx.type_bf16() {
|
||||
rust_ty == cx.type_vector(cx.type_i16(), element_count)
|
||||
} else if llvm_element_ty == cx.type_i1() {
|
||||
let int_width = element_count.next_power_of_two().max(8);
|
||||
rust_ty == cx.type_ix(int_width)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
TypeKind::BFloat => rust_ty == cx.type_i16(),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
@@ -1097,7 +1105,7 @@ fn autocast<'ll>(
|
||||
)
|
||||
}
|
||||
}
|
||||
_ => unreachable!(),
|
||||
_ => bx.bitcast(val, dest_ty), // for `bf16(xN)` <-> `u16(xN)`
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -921,6 +921,9 @@ pub(crate) fn LLVMGetInlineAsm<'ll>(
|
||||
pub(crate) fn LLVMDoubleTypeInContext(C: &Context) -> &Type;
|
||||
pub(crate) fn LLVMFP128TypeInContext(C: &Context) -> &Type;
|
||||
|
||||
// Operations on non-IEEE real types
|
||||
pub(crate) fn LLVMBFloatTypeInContext(C: &Context) -> &Type;
|
||||
|
||||
// Operations on function types
|
||||
pub(crate) fn LLVMFunctionType<'a>(
|
||||
ReturnType: &'a Type,
|
||||
|
||||
@@ -183,6 +183,10 @@ pub(crate) fn type_struct(&self, els: &[&'ll Type], packed: bool) -> &'ll Type {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn type_bf16(&self) -> &'ll Type {
|
||||
unsafe { llvm::LLVMBFloatTypeInContext(self.llcx()) }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ll, CX: Borrow<SCx<'ll>>> BaseTypeCodegenMethods for GenericCx<'ll, CX> {
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
//@ compile-flags: -C opt-level=0 -C target-feature=+kl,+avx512vp2intersect,+avx512vl
|
||||
//@ compile-flags: -C opt-level=0 -C target-feature=+kl,+avx512vp2intersect,+avx512vl,+avxneconvert
|
||||
//@ only-x86_64
|
||||
|
||||
#![feature(link_llvm_intrinsics, abi_unadjusted, simd_ffi, portable_simd)]
|
||||
#![crate_type = "lib"]
|
||||
|
||||
use std::simd::i64x2;
|
||||
use std::simd::{f32x4, i16x8, i64x2};
|
||||
|
||||
#[repr(C, packed)]
|
||||
pub struct Bar(u32, i64x2, i64x2, i64x2, i64x2, i64x2, i64x2);
|
||||
@@ -71,8 +71,23 @@ pub unsafe fn i1_vector_autocast(a: u8, b: u8) -> u8 {
|
||||
foo(a, b)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @bf16_vector_autocast
|
||||
#[no_mangle]
|
||||
pub unsafe fn bf16_vector_autocast(a: f32x4) -> i16x8 {
|
||||
extern "unadjusted" {
|
||||
#[link_name = "llvm.x86.vcvtneps2bf16128"]
|
||||
fn foo(a: f32x4) -> i16x8;
|
||||
}
|
||||
|
||||
// CHECK: [[A:%[0-9]+]] = call <8 x bfloat> @llvm.x86.vcvtneps2bf16128(<4 x float> {{.*}})
|
||||
// CHECK: bitcast <8 x bfloat> [[A]] to <8 x i16>
|
||||
foo(a)
|
||||
}
|
||||
|
||||
// CHECK: declare { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32, <2 x i64>)
|
||||
|
||||
// CHECK: declare { <2 x i1>, <2 x i1> } @llvm.x86.avx512.vp2intersect.q.128(<2 x i64>, <2 x i64>)
|
||||
|
||||
// CHECK: declare <8 x i1> @llvm.x86.avx512.kadd.b(<8 x i1>, <8 x i1>)
|
||||
|
||||
// CHECK: declare <8 x bfloat> @llvm.x86.vcvtneps2bf16128(<4 x float>)
|
||||
|
||||
Reference in New Issue
Block a user