Add immediate AMX intrinsics

This commit is contained in:
sayantn
2026-03-02 00:27:37 +05:30
parent e00790eb2c
commit 985cd2399a
@@ -398,6 +398,22 @@ pub unsafe fn _tile_cvtrowd2ps<const TILE: i32>(row: u32) -> __m512 {
tcvtrowd2ps(TILE as i8, row).as_m512()
}
/// Moves a row from a tile register to a zmm register, converting the packed 32-bit signed integer
/// elements to packed single-precision (32-bit) floating-point elements.
///
/// Immediate-operand variant of `_tile_cvtrowd2ps`: both the source tile (`TILE`,
/// must fit in 3 bits, i.e. 0..=7) and the row index (`ROW`, must fit in 6 bits,
/// i.e. 0..=63) are const generics, validated at compile time below.
#[inline]
#[rustc_legacy_const_generics(0, 1)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
all(test, any(target_os = "linux", target_env = "msvc")),
assert_instr(tcvtrowd2ps, TILE = 0, ROW = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowd2psi<const TILE: i32, const ROW: i32>() -> __m512 {
static_assert_uimm_bits!(TILE, 3);
static_assert_uimm_bits!(ROW, 6);
tcvtrowd2psi(TILE as i8, ROW as u32).as_m512()
}
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector.
@@ -414,6 +430,23 @@ pub unsafe fn _tile_cvtrowps2phh<const TILE: i32>(row: u32) -> __m512h {
tcvtrowps2phh(TILE as i8, row).as_m512h()
}
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector.
///
/// Immediate-operand variant of `_tile_cvtrowps2phh`: both the source tile (`TILE`,
/// 0..=7) and the row index (`ROW`, 0..=63) are const generics, validated at
/// compile time below.
#[inline]
#[rustc_legacy_const_generics(0, 1)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
all(test, any(target_os = "linux", target_env = "msvc")),
assert_instr(tcvtrowps2phh, TILE = 0, ROW = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowps2phhi<const TILE: i32, const ROW: i32>() -> __m512h {
static_assert_uimm_bits!(TILE, 3);
static_assert_uimm_bits!(ROW, 6);
tcvtrowps2phhi(TILE as i8, ROW as u32).as_m512h()
}
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector.
@@ -430,6 +463,23 @@ pub unsafe fn _tile_cvtrowps2phl<const TILE: i32>(row: u32) -> __m512h {
tcvtrowps2phl(TILE as i8, row).as_m512h()
}
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector.
///
/// Immediate-operand variant of `_tile_cvtrowps2phl`: both the source tile (`TILE`,
/// 0..=7) and the row index (`ROW`, 0..=63) are const generics, validated at
/// compile time below.
#[inline]
#[rustc_legacy_const_generics(0, 1)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
all(test, any(target_os = "linux", target_env = "msvc")),
assert_instr(tcvtrowps2phl, TILE = 0, ROW = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowps2phli<const TILE: i32, const ROW: i32>() -> __m512h {
static_assert_uimm_bits!(TILE, 3);
static_assert_uimm_bits!(ROW, 6);
tcvtrowps2phli(TILE as i8, ROW as u32).as_m512h()
}
/// Moves one row of tile data into a zmm vector register
#[inline]
#[rustc_legacy_const_generics(0)]
@@ -444,6 +494,21 @@ pub unsafe fn _tile_movrow<const TILE: i32>(row: u32) -> __m512i {
tilemovrow(TILE as i8, row).as_m512i()
}
/// Moves one row of tile data into a zmm vector register
///
/// Immediate-operand variant of `_tile_movrow`: both the source tile (`TILE`,
/// 0..=7) and the row index (`ROW`, 0..=63) are const generics, validated at
/// compile time below.
#[inline]
#[rustc_legacy_const_generics(0, 1)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
all(test, any(target_os = "linux", target_env = "msvc")),
assert_instr(tilemovrow, TILE = 0, ROW = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_movrowi<const TILE: i32, const ROW: i32>() -> __m512i {
static_assert_uimm_bits!(TILE, 3);
static_assert_uimm_bits!(ROW, 6);
tilemovrowi(TILE as i8, ROW as u32).as_m512i()
}
#[allow(improper_ctypes)]
unsafe extern "C" {
#[link_name = "llvm.x86.ldtilecfg"]
@@ -492,12 +557,20 @@ pub unsafe fn _tile_movrow<const TILE: i32>(row: u32) -> __m512i {
fn tmmultf32ps(dst: i8, a: i8, b: i8);
#[link_name = "llvm.x86.tcvtrowd2ps"]
fn tcvtrowd2ps(tile: i8, row: u32) -> f32x16;
#[link_name = "llvm.x86.tcvtrowd2psi"]
fn tcvtrowd2psi(tile: i8, row: u32) -> f32x16;
#[link_name = "llvm.x86.tcvtrowps2phh"]
fn tcvtrowps2phh(tile: i8, row: u32) -> f16x32;
#[link_name = "llvm.x86.tcvtrowps2phhi"]
fn tcvtrowps2phhi(tile: i8, row: u32) -> f16x32;
#[link_name = "llvm.x86.tcvtrowps2phl"]
fn tcvtrowps2phl(tile: i8, row: u32) -> f16x32;
#[link_name = "llvm.x86.tcvtrowps2phli"]
fn tcvtrowps2phli(tile: i8, row: u32) -> f16x32;
#[link_name = "llvm.x86.tilemovrow"]
fn tilemovrow(tile: i8, row: u32) -> i32x16;
#[link_name = "llvm.x86.tilemovrowi"]
fn tilemovrowi(tile: i8, row: u32) -> i32x16;
}
#[cfg(test)]
@@ -1032,6 +1105,50 @@ fn test_tile_movrow() {
}
}
// Test helper: bridges a runtime row index to the `ROW` const generic expected
// by the immediate AMX intrinsics. Expands to a 16-arm `match` that calls
// `$name::<$TILE, literal>()` for rows 0..=15 and panics on anything larger
// (the tests configure tiles with exactly 16 rows).
macro_rules! wrap_imm4 {
($name:ident :: <$TILE:literal>, $row:expr) => {
match $row {
0 => $name::<$TILE, 0>(),
1 => $name::<$TILE, 1>(),
2 => $name::<$TILE, 2>(),
3 => $name::<$TILE, 3>(),
4 => $name::<$TILE, 4>(),
5 => $name::<$TILE, 5>(),
6 => $name::<$TILE, 6>(),
7 => $name::<$TILE, 7>(),
8 => $name::<$TILE, 8>(),
9 => $name::<$TILE, 9>(),
10 => $name::<$TILE, 10>(),
11 => $name::<$TILE, 11>(),
12 => $name::<$TILE, 12>(),
13 => $name::<$TILE, 13>(),
14 => $name::<$TILE, 14>(),
15 => $name::<$TILE, 15>(),
_ => panic!("row index out of range"),
}
};
}
// Loads a 16x64 tile where every byte of row `r` equals `r`, then reads each
// row back through the immediate intrinsic and checks the bytes round-trip.
#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_movrowi() {
    unsafe {
        _init_amx();
        let data: [[u8; 64]; 16] = array::from_fn(|r| [r as u8; 64]);
        let mut cfg = __tilecfg::default();
        cfg.palette = 1;
        cfg.colsb[0] = 64;
        cfg.rows[0] = 16;
        _tile_loadconfig(cfg.as_ptr());
        _tile_loadd::<0>(data.as_ptr().cast(), 64);
        for r in 0..16 {
            let zmm = wrap_imm4!(_tile_movrowi::<0>, r);
            assert_eq!(*zmm.as_u8x64().as_array(), [r as u8; 64]);
        }
    }
}
#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_cvtrowd2ps() {
unsafe {
@@ -1051,6 +1168,26 @@ fn test_tile_cvtrowd2ps() {
}
}
// Each i32 element of tile row `r` holds the value `r`; the immediate-row read
// must return those values converted to single-precision floats.
#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_cvtrowd2psi() {
    unsafe {
        _init_amx();
        let data: [[u32; 16]; 16] = array::from_fn(|r| [r as u32; 16]);
        let mut cfg = __tilecfg::default();
        cfg.palette = 1;
        cfg.colsb[0] = 64;
        cfg.rows[0] = 16;
        _tile_loadconfig(cfg.as_ptr());
        _tile_loadd::<0>(data.as_ptr().cast(), 64);
        for r in 0..16 {
            let zmm = wrap_imm4!(_tile_cvtrowd2psi::<0>, r);
            assert_eq!(*zmm.as_f32x16().as_array(), [r as f32; 16]);
        }
    }
}
#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_cvtrowps2phh() {
unsafe {
@@ -1073,6 +1210,28 @@ fn test_tile_cvtrowps2phh() {
}
}
// Each f32 element of tile row `r` holds `r`; the converted f16 value lands in
// the high 16 bits of every 32-bit lane, i.e. the odd-indexed f16 lanes, while
// the even-indexed lanes read back as zero.
#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_cvtrowps2phhi() {
    unsafe {
        _init_amx();
        let data: [[f32; 16]; 16] = array::from_fn(|r| [r as f32; 16]);
        let mut cfg = __tilecfg::default();
        cfg.palette = 1;
        cfg.colsb[0] = 64;
        cfg.rows[0] = 16;
        _tile_loadconfig(cfg.as_ptr());
        _tile_loadd::<0>(data.as_ptr().cast(), 64);
        for r in 0..16 {
            let zmm = wrap_imm4!(_tile_cvtrowps2phhi::<0>, r);
            let expected = array::from_fn(|lane| if lane % 2 == 1 { r as _ } else { 0.0 });
            assert_eq!(*zmm.as_f16x32().as_array(), expected);
        }
    }
}
#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_cvtrowps2phl() {
unsafe {
@@ -1095,6 +1254,28 @@ fn test_tile_cvtrowps2phl() {
}
}
// Each f32 element of tile row `r` holds `r`; the converted f16 value lands in
// the low 16 bits of every 32-bit lane, i.e. the even-indexed f16 lanes, while
// the odd-indexed lanes read back as zero.
#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_cvtrowps2phli() {
    unsafe {
        _init_amx();
        let data: [[f32; 16]; 16] = array::from_fn(|r| [r as f32; 16]);
        let mut cfg = __tilecfg::default();
        cfg.palette = 1;
        cfg.colsb[0] = 64;
        cfg.rows[0] = 16;
        _tile_loadconfig(cfg.as_ptr());
        _tile_loadd::<0>(data.as_ptr().cast(), 64);
        for r in 0..16 {
            let zmm = wrap_imm4!(_tile_cvtrowps2phli::<0>, r);
            let expected = array::from_fn(|lane| if lane % 2 == 1 { 0.0 } else { r as _ });
            assert_eq!(*zmm.as_f16x32().as_array(), expected);
        }
    }
}
#[simd_test(enable = "amx-tf32")]
fn test_tile_mmultf32ps() {
unsafe {