Add autocast for bf16 and bf16xN

This commit is contained in:
sayantn
2025-11-25 23:28:37 +05:30
parent 5aa800af80
commit 11f350da38
4 changed files with 36 additions and 6 deletions
+12 -4
View File
@@ -1022,12 +1022,20 @@ fn can_autocast<'ll>(cx: &CodegenCx<'ll, '_>, rust_ty: &'ll Type, llvm_ty: &'ll
},
)
}
TypeKind::Vector if cx.element_type(llvm_ty) == cx.type_i1() => {
TypeKind::Vector => {
let llvm_element_ty = cx.element_type(llvm_ty);
let element_count = cx.vector_length(llvm_ty) as u64;
let int_width = element_count.next_power_of_two().max(8);
rust_ty == cx.type_ix(int_width)
if llvm_element_ty == cx.type_bf16() {
rust_ty == cx.type_vector(cx.type_i16(), element_count)
} else if llvm_element_ty == cx.type_i1() {
let int_width = element_count.next_power_of_two().max(8);
rust_ty == cx.type_ix(int_width)
} else {
false
}
}
TypeKind::BFloat => rust_ty == cx.type_i16(),
_ => false,
}
}
@@ -1097,7 +1105,7 @@ fn autocast<'ll>(
)
}
}
_ => unreachable!(),
_ => bx.bitcast(val, dest_ty), // for `bf16(xN)` <-> `u16(xN)`
}
}
@@ -921,6 +921,9 @@ pub(crate) fn LLVMGetInlineAsm<'ll>(
pub(crate) fn LLVMDoubleTypeInContext(C: &Context) -> &Type;
pub(crate) fn LLVMFP128TypeInContext(C: &Context) -> &Type;
// Operations on non-IEEE real types
pub(crate) fn LLVMBFloatTypeInContext(C: &Context) -> &Type;
// Operations on function types
pub(crate) fn LLVMFunctionType<'a>(
ReturnType: &'a Type,
+4
View File
@@ -183,6 +183,10 @@ pub(crate) fn type_struct(&self, els: &[&'ll Type], packed: bool) -> &'ll Type {
)
}
}
pub(crate) fn type_bf16(&self) -> &'ll Type {
unsafe { llvm::LLVMBFloatTypeInContext(self.llcx()) }
}
}
impl<'ll, CX: Borrow<SCx<'ll>>> BaseTypeCodegenMethods for GenericCx<'ll, CX> {
+17 -2
View File
@@ -1,10 +1,10 @@
//@ compile-flags: -C opt-level=0 -C target-feature=+kl,+avx512vp2intersect,+avx512vl
//@ compile-flags: -C opt-level=0 -C target-feature=+kl,+avx512vp2intersect,+avx512vl,+avxneconvert
//@ only-x86_64
#![feature(link_llvm_intrinsics, abi_unadjusted, simd_ffi, portable_simd)]
#![crate_type = "lib"]
use std::simd::i64x2;
use std::simd::{f32x4, i16x8, i64x2};
#[repr(C, packed)]
pub struct Bar(u32, i64x2, i64x2, i64x2, i64x2, i64x2, i64x2);
@@ -71,8 +71,23 @@ pub unsafe fn i1_vector_autocast(a: u8, b: u8) -> u8 {
foo(a, b)
}
// CHECK-LABEL: @bf16_vector_autocast
#[no_mangle]
pub unsafe fn bf16_vector_autocast(a: f32x4) -> i16x8 {
extern "unadjusted" {
#[link_name = "llvm.x86.vcvtneps2bf16128"]
fn foo(a: f32x4) -> i16x8;
}
// CHECK: [[A:%[0-9]+]] = call <8 x bfloat> @llvm.x86.vcvtneps2bf16128(<4 x float> {{.*}})
// CHECK: bitcast <8 x bfloat> [[A]] to <8 x i16>
foo(a)
}
// CHECK: declare { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32, <2 x i64>)
// CHECK: declare { <2 x i1>, <2 x i1> } @llvm.x86.avx512.vp2intersect.q.128(<2 x i64>, <2 x i64>)
// CHECK: declare <8 x i1> @llvm.x86.avx512.kadd.b(<8 x i1>, <8 x i1>)
// CHECK: declare <8 x bfloat> @llvm.x86.vcvtneps2bf16128(<4 x float>)