Rollup merge of #146181 - Flakebi:dynamic-shared-memory, r=ZuseZ4,Sa4dus,workingjubilee,RalfJung,nikic,kjetilkjeka,kulst

Add intrinsic for launch-sized workgroup memory on GPUs

Workgroup memory is a memory region that is shared between all
threads in a workgroup on GPUs. Workgroup memory can be allocated
statically or after compilation, when launching a gpu-kernel.
The intrinsic added here returns the pointer to the memory that is
allocated at launch-time.

# Interface

With this change, workgroup memory can be accessed in Rust by
calling the new `gpu_launch_sized_workgroup_mem<T>() -> *mut T`
intrinsic.

It returns the pointer to workgroup memory guaranteeing that it is
aligned to at least the alignment of `T`.
The pointer is dereferencable for the size specified when launching the
current gpu-kernel (which may be the size of `T` but can also be larger
or smaller or zero).

All calls to this intrinsic return a pointer to the same address.

See the intrinsic documentation for more details.

## Alternative Interfaces

It was also considered to expose dynamic workgroup memory as extern
static variables in Rust, like they are represented in LLVM IR.
However, due to the pointer not being guaranteed to be dereferencable
(that depends on the allocated size at runtime), such a global must be
zero-sized, which makes global variables a bad fit.

# Implementation Details

Workgroup memory in amdgpu and nvptx lives in address space 3.
Workgroup memory from a launch is implemented by creating an
external global variable in address space 3. The global is declared with
size 0, as the actual size is only known at runtime. It is defined
behavior in LLVM to access an external global outside the defined size.

There is no similar way to get the allocated size of launch-sized
workgroup memory on amdgpu an nvptx, so users have to pass this
out-of-band or rely on target specific ways for now.

Tracking issue: rust-lang/rust#135516
This commit is contained in:
Jonathan Brouwer
2026-04-25 23:07:48 +02:00
committed by GitHub
11 changed files with 193 additions and 9 deletions
+3
View File
@@ -1753,6 +1753,9 @@ pub fn index_by_increasing_offset(&self) -> impl ExactSizeIterator<Item = usize>
impl AddressSpace { impl AddressSpace {
/// LLVM's `0` address space. /// LLVM's `0` address space.
pub const ZERO: Self = AddressSpace(0); pub const ZERO: Self = AddressSpace(0);
/// The address space for workgroup memory on nvptx and amdgpu.
/// See e.g. the `gpu_launch_sized_workgroup_mem` intrinsic for details.
pub const GPU_WORKGROUP: Self = AddressSpace(3);
} }
/// How many scalable vectors are in a `BackendRepr::ScalableVector`? /// How many scalable vectors are in a `BackendRepr::ScalableVector`?
@@ -14,6 +14,7 @@
use std::borrow::Borrow; use std::borrow::Borrow;
use itertools::Itertools; use itertools::Itertools;
use rustc_abi::AddressSpace;
use rustc_codegen_ssa::traits::{MiscCodegenMethods, TypeMembershipCodegenMethods}; use rustc_codegen_ssa::traits::{MiscCodegenMethods, TypeMembershipCodegenMethods};
use rustc_data_structures::fx::FxIndexSet; use rustc_data_structures::fx::FxIndexSet;
use rustc_middle::ty::{Instance, Ty}; use rustc_middle::ty::{Instance, Ty};
@@ -104,6 +105,28 @@ pub(crate) fn declare_global(&self, name: &str, ty: &'ll Type) -> &'ll Value {
) )
} }
} }
/// Declare a global value in a specific address space.
///
/// If theres a value with the same name already declared, the function will
/// return its Value instead.
pub(crate) fn declare_global_in_addrspace(
&self,
name: &str,
ty: &'ll Type,
addr_space: AddressSpace,
) -> &'ll Value {
debug!("declare_global(name={name:?}, addrspace={addr_space:?})");
unsafe {
llvm::LLVMRustGetOrInsertGlobalInAddrspace(
(**self).borrow().llmod,
name.as_c_char_ptr(),
name.len(),
ty,
addr_space.0,
)
}
}
} }
impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> { impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
+45 -4
View File
@@ -3,8 +3,8 @@
use std::{assert_matches, iter, ptr}; use std::{assert_matches, iter, ptr};
use rustc_abi::{ use rustc_abi::{
Align, BackendRepr, Float, HasDataLayout, Integer, NumScalableVectors, Primitive, Size, AddressSpace, Align, BackendRepr, Float, HasDataLayout, Integer, NumScalableVectors, Primitive,
WrappingRange, Size, WrappingRange,
}; };
use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh}; use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh};
use rustc_codegen_ssa::common::{IntPredicate, TypeKind}; use rustc_codegen_ssa::common::{IntPredicate, TypeKind};
@@ -178,6 +178,7 @@ fn codegen_intrinsic_call(
span: Span, span: Span,
) -> Result<(), ty::Instance<'tcx>> { ) -> Result<(), ty::Instance<'tcx>> {
let tcx = self.tcx; let tcx = self.tcx;
let llvm_version = crate::llvm_util::get_version();
let name = tcx.item_name(instance.def_id()); let name = tcx.item_name(instance.def_id());
let fn_args = instance.args; let fn_args = instance.args;
@@ -194,7 +195,7 @@ fn codegen_intrinsic_call(
| sym::maximum_number_nsz_f64 | sym::maximum_number_nsz_f64
| sym::maximum_number_nsz_f128 | sym::maximum_number_nsz_f128
// Need at least LLVM 22 for `min/maximumnum` to not crash LLVM. // Need at least LLVM 22 for `min/maximumnum` to not crash LLVM.
if crate::llvm_util::get_version() >= (22, 0, 0) => if llvm_version >= (22, 0, 0) =>
{ {
let intrinsic_name = if name.as_str().starts_with("min") { let intrinsic_name = if name.as_str().starts_with("min") {
"llvm.minimumnum" "llvm.minimumnum"
@@ -420,7 +421,7 @@ fn codegen_intrinsic_call(
} }
// FIXME move into the branch below when LLVM 22 is the lowest version we support. // FIXME move into the branch below when LLVM 22 is the lowest version we support.
sym::carryless_mul if crate::llvm_util::get_version() >= (22, 0, 0) => { sym::carryless_mul if llvm_version >= (22, 0, 0) => {
let ty = args[0].layout.ty; let ty = args[0].layout.ty;
if !ty.is_integral() { if !ty.is_integral() {
tcx.dcx().emit_err(InvalidMonomorphization::BasicIntegerType { tcx.dcx().emit_err(InvalidMonomorphization::BasicIntegerType {
@@ -620,6 +621,46 @@ fn codegen_intrinsic_call(
return Ok(()); return Ok(());
} }
sym::gpu_launch_sized_workgroup_mem => {
// Generate an anonymous global per call, with these properties:
// 1. The global is in the address space for workgroup memory
// 2. It is an `external` global
// 3. It is correctly aligned for the pointee `T`
// All instances of extern addrspace(gpu_workgroup) globals are merged in the LLVM backend.
// The name is irrelevant.
// See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared
let name = if llvm_version < (23, 0, 0) && tcx.sess.target.arch == Arch::Nvptx64 {
// The auto-assigned name for extern shared globals in the nvptx backend does
// not compile in ptxas. Workaround this issue by assigning a name.
// Fixed in LLVM 23.
"gpu_launch_sized_workgroup_mem"
} else {
""
};
let global = self.declare_global_in_addrspace(
name,
self.type_array(self.type_i8(), 0),
AddressSpace::GPU_WORKGROUP,
);
let ty::RawPtr(inner_ty, _) = result.layout.ty.kind() else { unreachable!() };
// The alignment of the global is used to specify the *minimum* alignment that
// must be obeyed by the GPU runtime.
// When multiple of these global variables are used by a kernel, the maximum alignment is taken.
// See https://github.com/llvm/llvm-project/blob/a271d07488a85ce677674bbe8101b10efff58c95/llvm/lib/Target/AMDGPU/AMDGPULowerModuleLDSPass.cpp#L821
let alignment = self.align_of(*inner_ty).bytes() as u32;
unsafe {
// FIXME Workaround the above issue by taking maximum alignment if the global existed
if tcx.sess.target.arch == Arch::Nvptx64 {
if alignment > llvm::LLVMGetAlignment(global) {
llvm::LLVMSetAlignment(global, alignment);
}
} else {
llvm::LLVMSetAlignment(global, alignment);
}
}
self.cx().const_pointercast(global, self.type_ptr())
}
sym::amdgpu_dispatch_ptr => { sym::amdgpu_dispatch_ptr => {
let val = self.call_intrinsic("llvm.amdgcn.dispatch.ptr", &[], &[]); let val = self.call_intrinsic("llvm.amdgcn.dispatch.ptr", &[], &[]);
// Relying on `LLVMBuildPointerCast` to produce an addrspacecast // Relying on `LLVMBuildPointerCast` to produce an addrspacecast
@@ -2003,6 +2003,13 @@ pub(crate) fn LLVMRustGetOrInsertGlobal<'a>(
NameLen: size_t, NameLen: size_t,
T: &'a Type, T: &'a Type,
) -> &'a Value; ) -> &'a Value;
pub(crate) fn LLVMRustGetOrInsertGlobalInAddrspace<'a>(
M: &'a Module,
Name: *const c_char,
NameLen: size_t,
T: &'a Type,
AddressSpace: c_uint,
) -> &'a Value;
pub(crate) fn LLVMRustGetNamedValue( pub(crate) fn LLVMRustGetNamedValue(
M: &Module, M: &Module,
Name: *const c_char, Name: *const c_char,
@@ -111,6 +111,7 @@ pub fn codegen_intrinsic_call(
sym::abort sym::abort
| sym::unreachable | sym::unreachable
| sym::cold_path | sym::cold_path
| sym::gpu_launch_sized_workgroup_mem
| sym::breakpoint | sym::breakpoint
| sym::amdgpu_dispatch_ptr | sym::amdgpu_dispatch_ptr
| sym::assert_zero_valid | sym::assert_zero_valid
@@ -130,6 +130,7 @@ fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -> hi
| sym::forget | sym::forget
| sym::frem_algebraic | sym::frem_algebraic
| sym::fsub_algebraic | sym::fsub_algebraic
| sym::gpu_launch_sized_workgroup_mem
| sym::is_val_statically_known | sym::is_val_statically_known
| sym::log2f16 | sym::log2f16
| sym::log2f32 | sym::log2f32
@@ -297,6 +298,7 @@ pub(crate) fn check_intrinsic_type(
sym::field_offset => (1, 0, vec![], tcx.types.usize), sym::field_offset => (1, 0, vec![], tcx.types.usize),
sym::rustc_peek => (1, 0, vec![param(0)], param(0)), sym::rustc_peek => (1, 0, vec![param(0)], param(0)),
sym::caller_location => (0, 0, vec![], tcx.caller_location_ty()), sym::caller_location => (0, 0, vec![], tcx.caller_location_ty()),
sym::gpu_launch_sized_workgroup_mem => (1, 0, vec![], Ty::new_mut_ptr(tcx, param(0))),
sym::assert_inhabited | sym::assert_zero_valid | sym::assert_mem_uninitialized_valid => { sym::assert_inhabited | sym::assert_zero_valid | sym::assert_mem_uninitialized_valid => {
(1, 0, vec![], tcx.types.unit) (1, 0, vec![], tcx.types.unit)
} }
@@ -299,10 +299,12 @@ extern "C" LLVMValueRef LLVMRustGetOrInsertFunction(LLVMModuleRef M,
.getCallee()); .getCallee());
} }
extern "C" LLVMValueRef LLVMRustGetOrInsertGlobal(LLVMModuleRef M, // Get the global variable with the given name if it exists or create a new
const char *Name, // external global.
size_t NameLen, extern "C" LLVMValueRef
LLVMTypeRef Ty) { LLVMRustGetOrInsertGlobalInAddrspace(LLVMModuleRef M, const char *Name,
size_t NameLen, LLVMTypeRef Ty,
unsigned int AddressSpace) {
Module *Mod = unwrap(M); Module *Mod = unwrap(M);
auto NameRef = StringRef(Name, NameLen); auto NameRef = StringRef(Name, NameLen);
@@ -313,10 +315,24 @@ extern "C" LLVMValueRef LLVMRustGetOrInsertGlobal(LLVMModuleRef M,
GlobalVariable *GV = Mod->getGlobalVariable(NameRef, true); GlobalVariable *GV = Mod->getGlobalVariable(NameRef, true);
if (!GV) if (!GV)
GV = new GlobalVariable(*Mod, unwrap(Ty), false, GV = new GlobalVariable(*Mod, unwrap(Ty), false,
GlobalValue::ExternalLinkage, nullptr, NameRef); GlobalValue::ExternalLinkage, nullptr, NameRef,
nullptr, GlobalValue::NotThreadLocal, AddressSpace);
return wrap(GV); return wrap(GV);
} }
// Get the global variable with the given name if it exists or create a new
// external global.
extern "C" LLVMValueRef LLVMRustGetOrInsertGlobal(LLVMModuleRef M,
const char *Name,
size_t NameLen,
LLVMTypeRef Ty) {
Module *Mod = unwrap(M);
unsigned int AddressSpace =
Mod->getDataLayout().getDefaultGlobalsAddressSpace();
return LLVMRustGetOrInsertGlobalInAddrspace(M, Name, NameLen, Ty,
AddressSpace);
}
// Must match the layout of `rustc_codegen_llvm::llvm::ffi::AttributeKind`. // Must match the layout of `rustc_codegen_llvm::llvm::ffi::AttributeKind`.
enum class LLVMRustAttributeKind { enum class LLVMRustAttributeKind {
AlwaysInline = 0, AlwaysInline = 0,
+1
View File
@@ -1033,6 +1033,7 @@
global_asm, global_asm,
global_registration, global_registration,
globs, globs,
gpu_launch_sized_workgroup_mem,
gt, gt,
guard, guard,
guard_patterns, guard_patterns,
+45
View File
@@ -5,6 +5,51 @@
#![unstable(feature = "gpu_intrinsics", issue = "none")] #![unstable(feature = "gpu_intrinsics", issue = "none")]
/// Returns the pointer to workgroup memory allocated at launch-time on GPUs.
///
/// Workgroup memory is a memory region that is shared between all threads in
/// the same workgroup. It is faster to access than other memory but pointers do not
/// work outside the workgroup where they were obtained.
/// Workgroup memory can be allocated statically or after compilation, when
/// launching a gpu-kernel. `gpu_launch_sized_workgroup_mem` returns the pointer to
/// the memory that is allocated at launch-time.
/// The size of this memory can differ between launches of a gpu-kernel, depending on
/// what is specified at launch-time.
/// However, the alignment is fixed by the kernel itself, at compile-time.
///
/// The returned pointer is the start of the workgroup memory region that is
/// allocated at launch-time.
/// All calls to `gpu_launch_sized_workgroup_mem` in a workgroup, independent of the
/// generic type, return the same address, so alias the same memory.
/// The returned pointer is aligned by at least the alignment of `T`.
///
/// If `gpu_launch_sized_workgroup_mem` is invoked multiple times with different
/// types that have different alignment, then you may only rely on the resulting
/// pointer having the alignment of `T` after a call to `gpu_launch_sized_workgroup_mem::<T>`
/// has occurred in the current program execution.
///
/// # Safety
///
/// The pointer is safe to dereference from the start (the returned pointer) up to the
/// size of workgroup memory that was specified when launching the current gpu-kernel.
/// This allocated size is not related in any way to `T`.
///
/// The user must take care of synchronizing access to workgroup memory between
/// threads in a workgroup. The usual data race requirements apply.
///
/// # Other APIs
///
/// CUDA and HIP call this dynamic shared memory, shared between threads in a block.
/// OpenCL and SYCL call this local memory, shared between threads in a work-group.
/// GLSL calls this shared memory, shared between invocations in a work group.
/// DirectX calls this groupshared memory, shared between threads in a thread-group.
#[must_use = "returns a pointer that does nothing unless used"]
#[rustc_intrinsic]
#[rustc_nounwind]
#[unstable(feature = "gpu_launch_sized_workgroup_mem", issue = "135513")]
#[cfg(any(target_arch = "amdgpu", target_arch = "nvptx64"))]
pub fn gpu_launch_sized_workgroup_mem<T>() -> *mut T;
/// Returns a pointer to the HSA kernel dispatch packet. /// Returns a pointer to the HSA kernel dispatch packet.
/// ///
/// A `gpu-kernel` on amdgpu is always launched through a kernel dispatch packet. /// A `gpu-kernel` on amdgpu is always launched through a kernel dispatch packet.
+4
View File
@@ -222,6 +222,10 @@ fn should_ignore(line: &str) -> bool {
|| static_regex!( || static_regex!(
"\\s*//@ \\!?(count|files|has|has-dir|hasraw|matches|matchesraw|snapshot)\\s.*" "\\s*//@ \\!?(count|files|has|has-dir|hasraw|matches|matchesraw|snapshot)\\s.*"
).is_match(line) ).is_match(line)
// Matching for FileCheck checks
|| static_regex!(
"\\s*// [a-zA-Z0-9-_]*:\\s.*"
).is_match(line)
} }
/// Returns `true` if `line` is allowed to be longer than the normal limit. /// Returns `true` if `line` is allowed to be longer than the normal limit.
@@ -0,0 +1,41 @@
// Checks that the GPU intrinsic to get launch-sized workgroup memory works
// and correctly aligns the `external addrspace(...) global`s over multiple calls.
//@ revisions: amdgpu nvptx-pre-llvm-23 nvptx-post-llvm-23
//@ compile-flags: --crate-type=rlib -Copt-level=1
//
//@ [amdgpu] compile-flags: --target amdgcn-amd-amdhsa -Ctarget-cpu=gfx900
//@ [amdgpu] needs-llvm-components: amdgpu
//@ [nvptx-pre-llvm-23] compile-flags: --target nvptx64-nvidia-cuda
//@ [nvptx-pre-llvm-23] needs-llvm-components: nvptx
//@ [nvptx-pre-llvm-23] max-llvm-major-version: 22
//@ [nvptx-post-llvm-23] compile-flags: --target nvptx64-nvidia-cuda
//@ [nvptx-post-llvm-23] needs-llvm-components: nvptx
//@ [nvptx-post-llvm-23] min-llvm-version: 23
//@ add-minicore
#![feature(intrinsics, no_core, rustc_attrs)]
#![no_core]
extern crate minicore;
#[rustc_intrinsic]
#[rustc_nounwind]
fn gpu_launch_sized_workgroup_mem<T>() -> *mut T;
// amdgpu-DAG: @[[SMALL:[^ ]+]] = external addrspace(3) global [0 x i8], align 4
// amdgpu-DAG: @[[BIG:[^ ]+]] = external addrspace(3) global [0 x i8], align 8
// amdgpu: ret { ptr, ptr } { ptr addrspacecast (ptr addrspace(3) @[[SMALL]] to ptr), ptr addrspacecast (ptr addrspace(3) @[[BIG]] to ptr) }
// nvptx-pre-llvm-23: @[[BIG:[^ ]+]] = external addrspace(3) global [0 x i8], align 8
// nvptx-pre-llvm-23: ret { ptr, ptr } { ptr addrspacecast (ptr addrspace(3) @[[BIG]] to ptr), ptr addrspacecast (ptr addrspace(3) @[[BIG]] to ptr) }
// nvptx-post-llvm-23-DAG: @[[SMALL:[^ ]+]] = external addrspace(3) global [0 x i8], align 4
// nvptx-post-llvm-23-DAG: @[[BIG:[^ ]+]] = external addrspace(3) global [0 x i8], align 8
// nvptx-post-llvm-23: ret { ptr, ptr } { ptr addrspacecast (ptr addrspace(3) @[[SMALL]] to ptr), ptr addrspacecast (ptr addrspace(3) @[[BIG]] to ptr) }
#[unsafe(no_mangle)]
pub fn fun() -> (*mut i32, *mut f64) {
let small = gpu_launch_sized_workgroup_mem::<i32>();
let big = gpu_launch_sized_workgroup_mem::<f64>(); // Increase alignment to 8
(small, big)
}