mirror of
https://github.com/rust-lang/rust.git
synced 2026-04-26 13:01:27 +03:00
Add autocast for x86_amx
This commit is contained in:
@@ -1015,6 +1015,24 @@ fn can_autocast<'ll>(cx: &CodegenCx<'ll, '_>, rust_ty: &'ll Type, llvm_ty: &'ll
|
||||
}
|
||||
}
|
||||
TypeKind::BFloat => rust_ty == cx.type_i16(),
|
||||
TypeKind::X86_AMX if cx.type_kind(rust_ty) == TypeKind::Vector => {
|
||||
let element_ty = cx.element_type(rust_ty);
|
||||
let element_count = cx.vector_length(rust_ty) as u64;
|
||||
|
||||
let element_size_bits = match cx.type_kind(element_ty) {
|
||||
TypeKind::Half => 16,
|
||||
TypeKind::Float => 32,
|
||||
TypeKind::Double => 64,
|
||||
TypeKind::FP128 => 128,
|
||||
TypeKind::Integer => cx.int_width(element_ty),
|
||||
TypeKind::Pointer => cx.int_width(cx.isize_ty),
|
||||
_ => bug!(
|
||||
"Vector element type `{element_ty:?}` not one of integer, float or pointer"
|
||||
),
|
||||
};
|
||||
|
||||
element_size_bits * element_count == 8192
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
@@ -1084,6 +1102,12 @@ fn autocast<'ll>(
|
||||
)
|
||||
}
|
||||
}
|
||||
(TypeKind::Vector, TypeKind::X86_AMX) => {
|
||||
bx.call_intrinsic("llvm.x86.cast.vector.to.tile", &[src_ty], &[val])
|
||||
}
|
||||
(TypeKind::X86_AMX, TypeKind::Vector) => {
|
||||
bx.call_intrinsic("llvm.x86.cast.tile.to.vector", &[dest_ty], &[val])
|
||||
}
|
||||
_ => bx.bitcast(val, dest_ty), // for `bf16(xN)` <-> `u16(xN)`
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
//@ compile-flags: -C opt-level=0 -C target-feature=+kl,+avx512vp2intersect,+avx512vl,+avxneconvert
|
||||
//@ compile-flags: -C opt-level=0 -C target-feature=+kl,+avx512vp2intersect,+avx512vl,+avx512dq,+avxneconvert,+amx-int8
|
||||
//@ only-x86_64
|
||||
|
||||
#![feature(link_llvm_intrinsics, abi_unadjusted, simd_ffi, portable_simd)]
|
||||
#![feature(link_llvm_intrinsics, abi_unadjusted, simd_ffi, portable_simd, repr_simd)]
|
||||
#![crate_type = "lib"]
|
||||
|
||||
use std::simd::{f32x4, i16x8, i64x2};
|
||||
@@ -10,6 +10,9 @@
|
||||
pub struct Bar(u32, i64x2, i64x2, i64x2, i64x2, i64x2, i64x2);
|
||||
// CHECK: %Bar = type <{ i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> }>
|
||||
|
||||
#[repr(simd)]
|
||||
pub struct Tile([i8; 1024]);
|
||||
|
||||
// CHECK-LABEL: @struct_autocast
|
||||
#[no_mangle]
|
||||
pub unsafe fn struct_autocast(key_metadata: u32, key: i64x2) -> Bar {
|
||||
@@ -84,6 +87,22 @@ pub unsafe fn bf16_vector_autocast(a: f32x4) -> i16x8 {
|
||||
foo(a)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @amx_autocast
|
||||
#[no_mangle]
|
||||
pub unsafe fn amx_autocast(m: u16, n: u16, k: u16, a: Tile, b: Tile, c: Tile) -> Tile {
|
||||
extern "unadjusted" {
|
||||
#[link_name = "llvm.x86.tdpbuud.internal"]
|
||||
fn foo(m: u16, n: u16, k: u16, a: Tile, b: Tile, c: Tile) -> Tile;
|
||||
}
|
||||
|
||||
// CHECK: [[A:%[0-9]+]] = call x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8> {{.*}})
|
||||
// CHECK: [[B:%[0-9]+]] = call x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8> {{.*}})
|
||||
// CHECK: [[C:%[0-9]+]] = call x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8> {{.*}})
|
||||
// CHECK: [[D:%[0-9]+]] = call x86_amx @llvm.x86.tdpbuud.internal(i16 %m, i16 %n, i16 %k, x86_amx [[A]], x86_amx [[B]], x86_amx [[C]])
|
||||
// CHECK: call <1024 x i8> @llvm.x86.cast.tile.to.vector.v1024i8(x86_amx [[D]])
|
||||
foo(m, n, k, a, b, c)
|
||||
}
|
||||
|
||||
// CHECK: declare { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32, <2 x i64>)
|
||||
|
||||
// CHECK: declare { <2 x i1>, <2 x i1> } @llvm.x86.avx512.vp2intersect.q.128(<2 x i64>, <2 x i64>)
|
||||
@@ -91,3 +110,9 @@ pub unsafe fn bf16_vector_autocast(a: f32x4) -> i16x8 {
|
||||
// CHECK: declare <8 x i1> @llvm.x86.avx512.kadd.b(<8 x i1>, <8 x i1>)
|
||||
|
||||
// CHECK: declare <8 x bfloat> @llvm.x86.vcvtneps2bf16128(<4 x float>)
|
||||
|
||||
// CHECK: declare x86_amx @llvm.x86.tdpbuud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
|
||||
|
||||
// CHECK: declare x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8>)
|
||||
|
||||
// CHECK: declare <1024 x i8> @llvm.x86.cast.tile.to.vector.v1024i8(x86_amx)
|
||||
|
||||
Reference in New Issue
Block a user