diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 0d3d682ece21..34b3f0d81a70 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -1015,6 +1015,24 @@ fn can_autocast<'ll>(cx: &CodegenCx<'ll, '_>, rust_ty: &'ll Type, llvm_ty: &'ll } } TypeKind::BFloat => rust_ty == cx.type_i16(), + TypeKind::X86_AMX if cx.type_kind(rust_ty) == TypeKind::Vector => { + let element_ty = cx.element_type(rust_ty); + let element_count = cx.vector_length(rust_ty) as u64; + + let element_size_bits = match cx.type_kind(element_ty) { + TypeKind::Half => 16, + TypeKind::Float => 32, + TypeKind::Double => 64, + TypeKind::FP128 => 128, + TypeKind::Integer => cx.int_width(element_ty), + TypeKind::Pointer => cx.int_width(cx.isize_ty), + _ => bug!( + "Vector element type `{element_ty:?}` not one of integer, float or pointer" + ), + }; + + element_size_bits * element_count == 8192 + } _ => false, } } @@ -1084,6 +1102,12 @@ fn autocast<'ll>( ) } } + (TypeKind::Vector, TypeKind::X86_AMX) => { + bx.call_intrinsic("llvm.x86.cast.vector.to.tile", &[src_ty], &[val]) + } + (TypeKind::X86_AMX, TypeKind::Vector) => { + bx.call_intrinsic("llvm.x86.cast.tile.to.vector", &[dest_ty], &[val]) + } _ => bx.bitcast(val, dest_ty), // for `bf16(xN)` <-> `u16(xN)` } } diff --git a/tests/codegen-llvm/inject-autocast.rs b/tests/codegen-llvm/inject-autocast.rs index fec9d3f0b195..cc74256ebbe8 100644 --- a/tests/codegen-llvm/inject-autocast.rs +++ b/tests/codegen-llvm/inject-autocast.rs @@ -1,7 +1,7 @@ -//@ compile-flags: -C opt-level=0 -C target-feature=+kl,+avx512vp2intersect,+avx512vl,+avxneconvert +//@ compile-flags: -C opt-level=0 -C target-feature=+kl,+avx512vp2intersect,+avx512vl,+avx512dq,+avxneconvert,+amx-int8 //@ only-x86_64 -#![feature(link_llvm_intrinsics, abi_unadjusted, simd_ffi, portable_simd)] +#![feature(link_llvm_intrinsics, abi_unadjusted, simd_ffi, portable_simd, repr_simd)] #![crate_type = "lib"] use std::simd::{f32x4, i16x8, i64x2}; @@ -10,6 +10,9 @@ pub struct Bar(u32, i64x2, i64x2, i64x2, i64x2, i64x2, i64x2); // CHECK: %Bar = type <{ i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> }> +#[repr(simd)] +pub struct Tile([i8; 1024]); + // CHECK-LABEL: @struct_autocast #[no_mangle] pub unsafe fn struct_autocast(key_metadata: u32, key: i64x2) -> Bar { @@ -84,6 +87,22 @@ pub unsafe fn bf16_vector_autocast(a: f32x4) -> i16x8 { foo(a) } +// CHECK-LABEL: @amx_autocast +#[no_mangle] +pub unsafe fn amx_autocast(m: u16, n: u16, k: u16, a: Tile, b: Tile, c: Tile) -> Tile { + extern "unadjusted" { + #[link_name = "llvm.x86.tdpbuud.internal"] + fn foo(m: u16, n: u16, k: u16, a: Tile, b: Tile, c: Tile) -> Tile; + } + + // CHECK: [[A:%[0-9]+]] = call x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8> {{.*}}) + // CHECK: [[B:%[0-9]+]] = call x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8> {{.*}}) + // CHECK: [[C:%[0-9]+]] = call x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8> {{.*}}) + // CHECK: [[D:%[0-9]+]] = call x86_amx @llvm.x86.tdpbuud.internal(i16 %m, i16 %n, i16 %k, x86_amx [[A]], x86_amx [[B]], x86_amx [[C]]) + // CHECK: call <1024 x i8> @llvm.x86.cast.tile.to.vector.v1024i8(x86_amx [[D]]) + foo(m, n, k, a, b, c) +} + // CHECK: declare { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32, <2 x i64>) // CHECK: declare { <2 x i1>, <2 x i1> } @llvm.x86.avx512.vp2intersect.q.128(<2 x i64>, <2 x i64>) @@ -91,3 +110,9 @@ pub unsafe fn bf16_vector_autocast(a: f32x4) -> i16x8 { // CHECK: declare <8 x i1> @llvm.x86.avx512.kadd.b(<8 x i1>, <8 x i1>) // CHECK: declare <8 x bfloat> @llvm.x86.vcvtneps2bf16128(<4 x float>) + +// CHECK: declare x86_amx @llvm.x86.tdpbuud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) + +// CHECK: declare x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8>) + +// CHECK: declare <1024 x i8> @llvm.x86.cast.tile.to.vector.v1024i8(x86_amx)