Add autocast for x86_amx

This commit is contained in:
sayantn
2026-04-14 02:10:36 +05:30
parent 1a28bc4529
commit 7e24cd823d
2 changed files with 51 additions and 2 deletions
@@ -1015,6 +1015,24 @@ fn can_autocast<'ll>(cx: &CodegenCx<'ll, '_>, rust_ty: &'ll Type, llvm_ty: &'ll
}
}
TypeKind::BFloat => rust_ty == cx.type_i16(),
TypeKind::X86_AMX if cx.type_kind(rust_ty) == TypeKind::Vector => {
let element_ty = cx.element_type(rust_ty);
let element_count = cx.vector_length(rust_ty) as u64;
let element_size_bits = match cx.type_kind(element_ty) {
TypeKind::Half => 16,
TypeKind::Float => 32,
TypeKind::Double => 64,
TypeKind::FP128 => 128,
TypeKind::Integer => cx.int_width(element_ty),
TypeKind::Pointer => cx.int_width(cx.isize_ty),
_ => bug!(
"Vector element type `{element_ty:?}` not one of integer, float or pointer"
),
};
element_size_bits * element_count == 8192
}
_ => false,
}
}
@@ -1084,6 +1102,12 @@ fn autocast<'ll>(
)
}
}
(TypeKind::Vector, TypeKind::X86_AMX) => {
bx.call_intrinsic("llvm.x86.cast.vector.to.tile", &[src_ty], &[val])
}
(TypeKind::X86_AMX, TypeKind::Vector) => {
bx.call_intrinsic("llvm.x86.cast.tile.to.vector", &[dest_ty], &[val])
}
_ => bx.bitcast(val, dest_ty), // for `bf16(xN)` <-> `u16(xN)`
}
}
+27 -2
View File
@@ -1,7 +1,7 @@
//@ compile-flags: -C opt-level=0 -C target-feature=+kl,+avx512vp2intersect,+avx512vl,+avxneconvert
//@ compile-flags: -C opt-level=0 -C target-feature=+kl,+avx512vp2intersect,+avx512vl,+avx512dq,+avxneconvert,+amx-int8
//@ only-x86_64
#![feature(link_llvm_intrinsics, abi_unadjusted, simd_ffi, portable_simd)]
#![feature(link_llvm_intrinsics, abi_unadjusted, simd_ffi, portable_simd, repr_simd)]
#![crate_type = "lib"]
use std::simd::{f32x4, i16x8, i64x2};
@@ -10,6 +10,9 @@
pub struct Bar(u32, i64x2, i64x2, i64x2, i64x2, i64x2, i64x2);
// CHECK: %Bar = type <{ i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> }>
#[repr(simd)]
pub struct Tile([i8; 1024]);
// CHECK-LABEL: @struct_autocast
#[no_mangle]
pub unsafe fn struct_autocast(key_metadata: u32, key: i64x2) -> Bar {
@@ -84,6 +87,22 @@ pub unsafe fn bf16_vector_autocast(a: f32x4) -> i16x8 {
foo(a)
}
// CHECK-LABEL: @amx_autocast
#[no_mangle]
pub unsafe fn amx_autocast(m: u16, n: u16, k: u16, a: Tile, b: Tile, c: Tile) -> Tile {
extern "unadjusted" {
#[link_name = "llvm.x86.tdpbuud.internal"]
fn foo(m: u16, n: u16, k: u16, a: Tile, b: Tile, c: Tile) -> Tile;
}
// CHECK: [[A:%[0-9]+]] = call x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8> {{.*}})
// CHECK: [[B:%[0-9]+]] = call x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8> {{.*}})
// CHECK: [[C:%[0-9]+]] = call x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8> {{.*}})
// CHECK: [[D:%[0-9]+]] = call x86_amx @llvm.x86.tdpbuud.internal(i16 %m, i16 %n, i16 %k, x86_amx [[A]], x86_amx [[B]], x86_amx [[C]])
// CHECK: call <1024 x i8> @llvm.x86.cast.tile.to.vector.v1024i8(x86_amx [[D]])
foo(m, n, k, a, b, c)
}
// CHECK: declare { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32, <2 x i64>)
// CHECK: declare { <2 x i1>, <2 x i1> } @llvm.x86.avx512.vp2intersect.q.128(<2 x i64>, <2 x i64>)
@@ -91,3 +110,9 @@ pub unsafe fn bf16_vector_autocast(a: f32x4) -> i16x8 {
// CHECK: declare <8 x i1> @llvm.x86.avx512.kadd.b(<8 x i1>, <8 x i1>)
// CHECK: declare <8 x bfloat> @llvm.x86.vcvtneps2bf16128(<4 x float>)
// CHECK: declare x86_amx @llvm.x86.tdpbuud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
// CHECK: declare x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8>)
// CHECK: declare <1024 x i8> @llvm.x86.cast.tile.to.vector.v1024i8(x86_amx)