feat: dlopen Enzyme

This commit is contained in:
sgasho
2025-11-24 20:50:08 +09:00
parent ee447067e1
commit ddd5aad8a3
13 changed files with 536 additions and 246 deletions
+1
View File
@@ -3613,6 +3613,7 @@ dependencies = [
"gimli 0.31.1",
"itertools",
"libc",
"libloading 0.9.0",
"measureme",
"object 0.37.3",
"rustc-demangle",
+2 -2
View File
@@ -14,6 +14,7 @@ bitflags = "2.4.1"
gimli = "0.31"
itertools = "0.12"
libc = "0.2"
libloading = { version = "0.9.0", optional = true }
measureme = "12.0.1"
object = { version = "0.37.0", default-features = false, features = ["std", "read"] }
rustc-demangle = "0.1.21"
@@ -46,7 +47,6 @@ tracing = "0.1"
[features]
# tidy-alphabetical-start
check_only = ["rustc_llvm/check_only"]
llvm_enzyme = []
llvm_enzyme = ["dep:libloading"]
llvm_offload = []
# tidy-alphabetical-end
+17 -15
View File
@@ -528,31 +528,37 @@ fn thin_lto(
}
}
fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
#[cfg(feature = "llvm_enzyme")]
pub(crate) fn enable_autodiff_settings(
sysroot: &rustc_session::config::Sysroot,
ad: &[config::AutoDiff],
) {
let mut enzyme = llvm::EnzymeWrapper::get_or_init(sysroot);
for val in ad {
// We intentionally don't use a wildcard, to not forget handling anything new.
match val {
config::AutoDiff::PrintPerf => {
llvm::set_print_perf(true);
enzyme.set_print_perf(true);
}
config::AutoDiff::PrintAA => {
llvm::set_print_activity(true);
enzyme.set_print_activity(true);
}
config::AutoDiff::PrintTA => {
llvm::set_print_type(true);
enzyme.set_print_type(true);
}
config::AutoDiff::PrintTAFn(fun) => {
llvm::set_print_type(true); // Enable general type printing
llvm::set_print_type_fun(&fun); // Set specific function to analyze
enzyme.set_print_type(true); // Enable general type printing
enzyme.set_print_type_fun(&fun); // Set specific function to analyze
}
config::AutoDiff::Inline => {
llvm::set_inline(true);
enzyme.set_inline(true);
}
config::AutoDiff::LooseTypes => {
llvm::set_loose_types(true);
enzyme.set_loose_types(true);
}
config::AutoDiff::PrintSteps => {
llvm::set_print(true);
enzyme.set_print(true);
}
// We handle this in the PassWrapper.cpp
config::AutoDiff::PrintPasses => {}
@@ -571,9 +577,9 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
}
}
// This helps with handling enums for now.
llvm::set_strict_aliasing(false);
enzyme.set_strict_aliasing(false);
// FIXME(ZuseZ4): Test this, since it was added a long time ago.
llvm::set_rust_rules(true);
enzyme.set_rust_rules(true);
}
pub(crate) fn run_pass_manager(
@@ -607,10 +613,6 @@ pub(crate) fn run_pass_manager(
if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD }
};
if enable_ad {
enable_autodiff_settings(&config.autodiff);
}
unsafe {
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage);
}
@@ -730,6 +730,13 @@ fn handle_offload<'ll>(cx: &'ll SimpleCx<'_>, old_fn: &llvm::Value) {
let llvm_plugins = config.llvm_plugins.join(",");
let enzyme_fn = if consider_ad {
let wrapper = llvm::EnzymeWrapper::get_or_init(&cgcx.sysroot);
wrapper.registerEnzymeAndPassPipeline
} else {
std::ptr::null()
};
let result = unsafe {
llvm::LLVMRustOptimize(
module.module_llvm.llmod(),
@@ -749,7 +756,7 @@ fn handle_offload<'ll>(cx: &'ll SimpleCx<'_>, old_fn: &llvm::Value) {
vectorize_loop,
config.no_builtins,
config.emit_lifetime_markers,
run_enzyme,
enzyme_fn,
print_before_enzyme,
print_after_enzyme,
print_passes,
+12
View File
@@ -240,6 +240,18 @@ fn name(&self) -> &'static str {
fn init(&self, sess: &Session) {
llvm_util::init(sess); // Make sure llvm is inited
#[cfg(feature = "llvm_enzyme")]
{
use rustc_session::config::AutoDiff;
if sess.opts.unstable_opts.autodiff.contains(&AutoDiff::Enable) {
{
use crate::back::lto::enable_autodiff_settings;
enable_autodiff_settings(&sess.opts.sysroot, &sess.opts.unstable_opts.autodiff);
}
}
}
}
fn provide(&self, providers: &mut Providers) {
+476 -172
View File
@@ -91,102 +91,363 @@ pub(crate) enum LLVMRustVerifierFailureAction {
#[cfg(feature = "llvm_enzyme")]
pub(crate) mod Enzyme_AD {
use std::ffi::{CString, c_char};
use std::ffi::{c_char, c_void};
use std::sync::{Mutex, MutexGuard, OnceLock};
use libc::c_void;
use rustc_middle::bug;
use rustc_session::config::{Sysroot, host_tuple};
use rustc_session::filesearch;
use super::{CConcreteType, CTypeTreeRef, Context};
use crate::llvm::{EnzymeTypeTree, LLVMRustVersionMajor};
unsafe extern "C" {
pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char);
type EnzymeSetCLBoolFn = unsafe extern "C" fn(*mut c_void, u8);
type EnzymeSetCLStringFn = unsafe extern "C" fn(*mut c_void, *const c_char);
type EnzymeNewTypeTreeFn = unsafe extern "C" fn() -> CTypeTreeRef;
type EnzymeNewTypeTreeCTFn = unsafe extern "C" fn(CConcreteType, &Context) -> CTypeTreeRef;
type EnzymeNewTypeTreeTRFn = unsafe extern "C" fn(CTypeTreeRef) -> CTypeTreeRef;
type EnzymeFreeTypeTreeFn = unsafe extern "C" fn(CTypeTreeRef);
type EnzymeMergeTypeTreeFn = unsafe extern "C" fn(CTypeTreeRef, CTypeTreeRef) -> bool;
type EnzymeTypeTreeOnlyEqFn = unsafe extern "C" fn(CTypeTreeRef, i64);
type EnzymeTypeTreeData0EqFn = unsafe extern "C" fn(CTypeTreeRef);
type EnzymeTypeTreeShiftIndiciesEqFn =
unsafe extern "C" fn(CTypeTreeRef, *const c_char, i64, i64, u64);
type EnzymeTypeTreeInsertEqFn =
unsafe extern "C" fn(CTypeTreeRef, *const i64, usize, CConcreteType, &Context);
type EnzymeTypeTreeToStringFn = unsafe extern "C" fn(CTypeTreeRef) -> *const c_char;
type EnzymeTypeTreeToStringFreeFn = unsafe extern "C" fn(*const c_char);
#[allow(non_snake_case)]
pub(crate) struct EnzymeWrapper {
EnzymeNewTypeTree: EnzymeNewTypeTreeFn,
EnzymeNewTypeTreeCT: EnzymeNewTypeTreeCTFn,
EnzymeNewTypeTreeTR: EnzymeNewTypeTreeTRFn,
EnzymeFreeTypeTree: EnzymeFreeTypeTreeFn,
EnzymeMergeTypeTree: EnzymeMergeTypeTreeFn,
EnzymeTypeTreeOnlyEq: EnzymeTypeTreeOnlyEqFn,
EnzymeTypeTreeData0Eq: EnzymeTypeTreeData0EqFn,
EnzymeTypeTreeShiftIndiciesEq: EnzymeTypeTreeShiftIndiciesEqFn,
EnzymeTypeTreeInsertEq: EnzymeTypeTreeInsertEqFn,
EnzymeTypeTreeToString: EnzymeTypeTreeToStringFn,
EnzymeTypeTreeToStringFree: EnzymeTypeTreeToStringFreeFn,
EnzymePrintPerf: *mut c_void,
EnzymePrintActivity: *mut c_void,
EnzymePrintType: *mut c_void,
EnzymeFunctionToAnalyze: *mut c_void,
EnzymePrint: *mut c_void,
EnzymeStrictAliasing: *mut c_void,
EnzymeInline: *mut c_void,
EnzymeMaxTypeDepth: *mut c_void,
RustTypeRules: *mut c_void,
looseTypeAnalysis: *mut c_void,
EnzymeSetCLBool: EnzymeSetCLBoolFn,
EnzymeSetCLString: EnzymeSetCLStringFn,
pub registerEnzymeAndPassPipeline: *const c_void,
lib: libloading::Library,
}
// TypeTree functions
unsafe extern "C" {
pub(crate) fn EnzymeNewTypeTree() -> CTypeTreeRef;
pub(crate) fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef;
pub(crate) fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef;
pub(crate) fn EnzymeFreeTypeTree(CTT: CTypeTreeRef);
pub(crate) fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool;
pub(crate) fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64);
pub(crate) fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef);
pub(crate) fn EnzymeTypeTreeShiftIndiciesEq(
arg1: CTypeTreeRef,
unsafe impl Sync for EnzymeWrapper {}
unsafe impl Send for EnzymeWrapper {}
fn load_ptr_by_symbol_mut_void(
lib: &libloading::Library,
bytes: &[u8],
) -> Result<*mut c_void, Box<dyn std::error::Error>> {
unsafe {
let s: libloading::Symbol<'_, *mut c_void> = lib.get(bytes)?;
// libloading = 0.9.0: try_as_raw_ptr always succeeds and returns Some
let s = s.try_as_raw_ptr().unwrap();
Ok(s)
}
}
// e.g.
// load_ptrs_by_symbols_mut_void(ABC, XYZ);
// =>
// let ABC = load_ptr_mut_void(&lib, b"ABC")?;
// let XYZ = load_ptr_mut_void(&lib, b"XYZ")?;
macro_rules! load_ptrs_by_symbols_mut_void {
($lib:expr, $($name:ident),* $(,)?) => {
$(
#[allow(non_snake_case)]
let $name = load_ptr_by_symbol_mut_void(&$lib, stringify!($name).as_bytes())?;
)*
};
}
// e.g.
// load_ptrs_by_symbols_fn(ABC: ABCFn, XYZ: XYZFn);
// =>
// let ABC: libloading::Symbol<'_, ABCFn> = unsafe { lib.get(b"ABC")? };
// let XYZ: libloading::Symbol<'_, XYZFn> = unsafe { lib.get(b"XYZ")? };
macro_rules! load_ptrs_by_symbols_fn {
($lib:expr, $($name:ident : $ty:ty),* $(,)?) => {
$(
#[allow(non_snake_case)]
let $name: $ty = *unsafe { $lib.get::<$ty>(stringify!($name).as_bytes())? };
)*
};
}
static ENZYME_INSTANCE: OnceLock<Mutex<EnzymeWrapper>> = OnceLock::new();
impl EnzymeWrapper {
/// Initialize EnzymeWrapper with the given sysroot if not already initialized.
/// Safe to call multiple times - subsequent calls are no-ops due to OnceLock.
pub(crate) fn get_or_init(
sysroot: &rustc_session::config::Sysroot,
) -> MutexGuard<'static, Self> {
ENZYME_INSTANCE
.get_or_init(|| {
Self::call_dynamic(sysroot)
.unwrap_or_else(|e| bug!("failed to load Enzyme: {e}"))
.into()
})
.lock()
.unwrap()
}
/// Get the EnzymeWrapper instance. Panics if not initialized.
pub(crate) fn get_instance() -> MutexGuard<'static, Self> {
ENZYME_INSTANCE
.get()
.expect("EnzymeWrapper not initialized. Call get_or_init with sysroot first.")
.lock()
.unwrap()
}
pub(crate) fn new_type_tree(&self) -> CTypeTreeRef {
unsafe { (self.EnzymeNewTypeTree)() }
}
pub(crate) fn new_type_tree_ct(
&self,
t: CConcreteType,
ctx: &Context,
) -> *mut EnzymeTypeTree {
unsafe { (self.EnzymeNewTypeTreeCT)(t, ctx) }
}
pub(crate) fn new_type_tree_tr(&self, tree: CTypeTreeRef) -> CTypeTreeRef {
unsafe { (self.EnzymeNewTypeTreeTR)(tree) }
}
pub(crate) fn free_type_tree(&self, tree: CTypeTreeRef) {
unsafe { (self.EnzymeFreeTypeTree)(tree) }
}
pub(crate) fn merge_type_tree(&self, tree1: CTypeTreeRef, tree2: CTypeTreeRef) -> bool {
unsafe { (self.EnzymeMergeTypeTree)(tree1, tree2) }
}
pub(crate) fn tree_only_eq(&self, tree: CTypeTreeRef, num: i64) {
unsafe { (self.EnzymeTypeTreeOnlyEq)(tree, num) }
}
pub(crate) fn tree_data0_eq(&self, tree: CTypeTreeRef) {
unsafe { (self.EnzymeTypeTreeData0Eq)(tree) }
}
pub(crate) fn shift_indicies_eq(
&self,
tree: CTypeTreeRef,
data_layout: *const c_char,
offset: i64,
max_size: i64,
add_offset: u64,
);
pub(crate) fn EnzymeTypeTreeInsertEq(
CTT: CTypeTreeRef,
) {
unsafe {
(self.EnzymeTypeTreeShiftIndiciesEq)(
tree,
data_layout,
offset,
max_size,
add_offset,
)
}
}
pub(crate) fn tree_insert_eq(
&self,
tree: CTypeTreeRef,
indices: *const i64,
len: usize,
ct: CConcreteType,
ctx: &Context,
);
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
}
) {
unsafe { (self.EnzymeTypeTreeInsertEq)(tree, indices, len, ct, ctx) }
}
unsafe extern "C" {
static mut EnzymePrintPerf: c_void;
static mut EnzymePrintActivity: c_void;
static mut EnzymePrintType: c_void;
static mut EnzymeFunctionToAnalyze: c_void;
static mut EnzymePrint: c_void;
static mut EnzymeStrictAliasing: c_void;
static mut looseTypeAnalysis: c_void;
static mut EnzymeInline: c_void;
static mut RustTypeRules: c_void;
}
pub(crate) fn set_print_perf(print: bool) {
unsafe {
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintPerf), print as u8);
pub(crate) fn tree_to_string(&self, tree: *mut EnzymeTypeTree) -> *const c_char {
unsafe { (self.EnzymeTypeTreeToString)(tree) }
}
}
pub(crate) fn set_print_activity(print: bool) {
unsafe {
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintActivity), print as u8);
pub(crate) fn tree_to_string_free(&self, ch: *const c_char) {
unsafe { (self.EnzymeTypeTreeToStringFree)(ch) }
}
}
pub(crate) fn set_print_type(print: bool) {
unsafe {
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8);
pub(crate) fn get_max_type_depth(&self) -> usize {
unsafe { std::ptr::read::<u32>(self.EnzymeMaxTypeDepth as *const u32) as usize }
}
}
pub(crate) fn set_print_type_fun(fun_name: &str) {
let c_fun_name = CString::new(fun_name).unwrap();
unsafe {
EnzymeSetCLString(
std::ptr::addr_of_mut!(EnzymeFunctionToAnalyze),
c_fun_name.as_ptr() as *const c_char,
pub(crate) fn set_print_perf(&mut self, print: bool) {
unsafe {
(self.EnzymeSetCLBool)(self.EnzymePrintPerf, print as u8);
}
}
pub(crate) fn set_print_activity(&mut self, print: bool) {
unsafe {
(self.EnzymeSetCLBool)(self.EnzymePrintActivity, print as u8);
}
}
pub(crate) fn set_print_type(&mut self, print: bool) {
unsafe {
(self.EnzymeSetCLBool)(self.EnzymePrintType, print as u8);
}
}
pub(crate) fn set_print_type_fun(&mut self, fun_name: &str) {
let c_fun_name = std::ffi::CString::new(fun_name)
.unwrap_or_else(|err| bug!("failed to set_print_type_fun: {err}"));
unsafe {
(self.EnzymeSetCLString)(
self.EnzymeFunctionToAnalyze,
c_fun_name.as_ptr() as *const c_char,
);
}
}
pub(crate) fn set_print(&mut self, print: bool) {
unsafe {
(self.EnzymeSetCLBool)(self.EnzymePrint, print as u8);
}
}
pub(crate) fn set_strict_aliasing(&mut self, strict: bool) {
unsafe {
(self.EnzymeSetCLBool)(self.EnzymeStrictAliasing, strict as u8);
}
}
pub(crate) fn set_loose_types(&mut self, loose: bool) {
unsafe {
(self.EnzymeSetCLBool)(self.looseTypeAnalysis, loose as u8);
}
}
pub(crate) fn set_inline(&mut self, val: bool) {
unsafe {
(self.EnzymeSetCLBool)(self.EnzymeInline, val as u8);
}
}
pub(crate) fn set_rust_rules(&mut self, val: bool) {
unsafe {
(self.EnzymeSetCLBool)(self.RustTypeRules, val as u8);
}
}
#[allow(non_snake_case)]
fn call_dynamic(
sysroot: &rustc_session::config::Sysroot,
) -> Result<Self, Box<dyn std::error::Error>> {
let enzyme_path = Self::get_enzyme_path(sysroot)?;
let lib = unsafe { libloading::Library::new(enzyme_path)? };
load_ptrs_by_symbols_fn!(
lib,
EnzymeNewTypeTree: EnzymeNewTypeTreeFn,
EnzymeNewTypeTreeCT: EnzymeNewTypeTreeCTFn,
EnzymeNewTypeTreeTR: EnzymeNewTypeTreeTRFn,
EnzymeFreeTypeTree: EnzymeFreeTypeTreeFn,
EnzymeMergeTypeTree: EnzymeMergeTypeTreeFn,
EnzymeTypeTreeOnlyEq: EnzymeTypeTreeOnlyEqFn,
EnzymeTypeTreeData0Eq: EnzymeTypeTreeData0EqFn,
EnzymeTypeTreeShiftIndiciesEq: EnzymeTypeTreeShiftIndiciesEqFn,
EnzymeTypeTreeInsertEq: EnzymeTypeTreeInsertEqFn,
EnzymeTypeTreeToString: EnzymeTypeTreeToStringFn,
EnzymeTypeTreeToStringFree: EnzymeTypeTreeToStringFreeFn,
EnzymeSetCLBool: EnzymeSetCLBoolFn,
EnzymeSetCLString: EnzymeSetCLStringFn,
);
load_ptrs_by_symbols_mut_void!(
lib,
registerEnzymeAndPassPipeline,
EnzymePrintPerf,
EnzymePrintActivity,
EnzymePrintType,
EnzymeFunctionToAnalyze,
EnzymePrint,
EnzymeStrictAliasing,
EnzymeInline,
EnzymeMaxTypeDepth,
RustTypeRules,
looseTypeAnalysis,
);
Ok(Self {
EnzymeNewTypeTree,
EnzymeNewTypeTreeCT,
EnzymeNewTypeTreeTR,
EnzymeFreeTypeTree,
EnzymeMergeTypeTree,
EnzymeTypeTreeOnlyEq,
EnzymeTypeTreeData0Eq,
EnzymeTypeTreeShiftIndiciesEq,
EnzymeTypeTreeInsertEq,
EnzymeTypeTreeToString,
EnzymeTypeTreeToStringFree,
EnzymePrintPerf,
EnzymePrintActivity,
EnzymePrintType,
EnzymeFunctionToAnalyze,
EnzymePrint,
EnzymeStrictAliasing,
EnzymeInline,
EnzymeMaxTypeDepth,
RustTypeRules,
looseTypeAnalysis,
EnzymeSetCLBool,
EnzymeSetCLString,
registerEnzymeAndPassPipeline,
lib,
})
}
}
pub(crate) fn set_print(print: bool) {
unsafe {
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8);
}
}
pub(crate) fn set_strict_aliasing(strict: bool) {
unsafe {
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeStrictAliasing), strict as u8);
}
}
pub(crate) fn set_loose_types(loose: bool) {
unsafe {
EnzymeSetCLBool(std::ptr::addr_of_mut!(looseTypeAnalysis), loose as u8);
}
}
pub(crate) fn set_inline(val: bool) {
unsafe {
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeInline), val as u8);
}
}
pub(crate) fn set_rust_rules(val: bool) {
unsafe {
EnzymeSetCLBool(std::ptr::addr_of_mut!(RustTypeRules), val as u8);
fn get_enzyme_path(sysroot: &Sysroot) -> Result<String, String> {
let llvm_version_major = unsafe { LLVMRustVersionMajor() };
let path_buf = sysroot
.all_paths()
.map(|sysroot_path| {
filesearch::make_target_lib_path(sysroot_path, host_tuple())
.join("lib")
.with_file_name(format!("libEnzyme-{llvm_version_major}"))
.with_extension(std::env::consts::DLL_EXTENSION)
})
.find(|f| f.exists())
.ok_or_else(|| {
let candidates = sysroot
.all_paths()
.map(|p| p.join("lib").display().to_string())
.collect::<Vec<String>>()
.join("\n* ");
format!(
"failed to find a `libEnzyme-{llvm_version_major}` folder \
in the sysroot candidates:\n* {candidates}"
)
})?;
Ok(path_buf
.to_str()
.ok_or_else(|| format!("invalid UTF-8 in path: {}", path_buf.display()))?
.to_string())
}
}
}
@@ -198,111 +459,156 @@ pub(crate) fn set_rust_rules(val: bool) {
pub(crate) mod Fallback_AD {
#![allow(unused_variables)]
use std::ffi::c_void;
use std::sync::{Mutex, MutexGuard};
use libc::c_char;
use rustc_codegen_ssa::back::write::CodegenContext;
use rustc_codegen_ssa::traits::WriteBackendMethods;
use super::{CConcreteType, CTypeTreeRef, Context};
use super::{CConcreteType, CTypeTreeRef, Context, EnzymeTypeTree};
// TypeTree function fallbacks
pub(crate) unsafe fn EnzymeNewTypeTree() -> CTypeTreeRef {
unimplemented!()
pub(crate) struct EnzymeWrapper {
pub registerEnzymeAndPassPipeline: *const c_void,
}
pub(crate) unsafe fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef {
unimplemented!()
}
impl EnzymeWrapper {
pub(crate) fn get_or_init(
_sysroot: &rustc_session::config::Sysroot,
) -> MutexGuard<'static, Self> {
unimplemented!("Enzyme not available: build with llvm_enzyme feature")
}
pub(crate) unsafe fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef {
unimplemented!()
}
pub(crate) fn init<'a, B: WriteBackendMethods>(
_cgcx: &'a CodegenContext<B>,
) -> &'static Mutex<Self> {
unimplemented!("Enzyme not available: build with llvm_enzyme feature")
}
pub(crate) unsafe fn EnzymeFreeTypeTree(CTT: CTypeTreeRef) {
unimplemented!()
}
pub(crate) fn get_instance() -> MutexGuard<'static, Self> {
unimplemented!("Enzyme not available: build with llvm_enzyme feature")
}
pub(crate) unsafe fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool {
unimplemented!()
}
pub(crate) fn new_type_tree(&self) -> CTypeTreeRef {
unimplemented!()
}
pub(crate) unsafe fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64) {
unimplemented!()
}
pub(crate) fn new_type_tree_ct(
&self,
t: CConcreteType,
ctx: &Context,
) -> *mut EnzymeTypeTree {
unimplemented!()
}
pub(crate) unsafe fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef) {
unimplemented!()
}
pub(crate) fn new_type_tree_tr(&self, tree: CTypeTreeRef) -> CTypeTreeRef {
unimplemented!()
}
pub(crate) unsafe fn EnzymeTypeTreeShiftIndiciesEq(
arg1: CTypeTreeRef,
data_layout: *const c_char,
offset: i64,
max_size: i64,
add_offset: u64,
) {
unimplemented!()
}
pub(crate) fn free_type_tree(&self, tree: CTypeTreeRef) {
unimplemented!()
}
pub(crate) unsafe fn EnzymeTypeTreeInsertEq(
CTT: CTypeTreeRef,
indices: *const i64,
len: usize,
ct: CConcreteType,
ctx: &Context,
) {
unimplemented!()
}
pub(crate) fn merge_type_tree(&self, tree1: CTypeTreeRef, tree2: CTypeTreeRef) -> bool {
unimplemented!()
}
pub(crate) unsafe fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char {
unimplemented!()
}
pub(crate) fn tree_only_eq(&self, tree: CTypeTreeRef, num: i64) {
unimplemented!()
}
pub(crate) unsafe fn EnzymeTypeTreeToStringFree(arg1: *const c_char) {
unimplemented!()
}
pub(crate) fn tree_data0_eq(&self, tree: CTypeTreeRef) {
unimplemented!()
}
pub(crate) fn set_inline(val: bool) {
unimplemented!()
}
pub(crate) fn set_print_perf(print: bool) {
unimplemented!()
}
pub(crate) fn set_print_activity(print: bool) {
unimplemented!()
}
pub(crate) fn set_print_type(print: bool) {
unimplemented!()
}
pub(crate) fn set_print_type_fun(fun_name: &str) {
unimplemented!()
}
pub(crate) fn set_print(print: bool) {
unimplemented!()
}
pub(crate) fn set_strict_aliasing(strict: bool) {
unimplemented!()
}
pub(crate) fn set_loose_types(loose: bool) {
unimplemented!()
}
pub(crate) fn set_rust_rules(val: bool) {
unimplemented!()
pub(crate) fn shift_indicies_eq(
&self,
tree: CTypeTreeRef,
data_layout: *const c_char,
offset: i64,
max_size: i64,
add_offset: u64,
) {
unimplemented!()
}
pub(crate) fn tree_insert_eq(
&self,
tree: CTypeTreeRef,
indices: *const i64,
len: usize,
ct: CConcreteType,
ctx: &Context,
) {
unimplemented!()
}
pub(crate) fn tree_to_string(&self, tree: *mut EnzymeTypeTree) -> *const c_char {
unimplemented!()
}
pub(crate) fn tree_to_string_free(&self, ch: *const c_char) {
unimplemented!()
}
pub(crate) fn get_max_type_depth(&self) -> usize {
unimplemented!()
}
pub(crate) fn set_inline(&mut self, val: bool) {
unimplemented!()
}
pub(crate) fn set_print_perf(&mut self, print: bool) {
unimplemented!()
}
pub(crate) fn set_print_activity(&mut self, print: bool) {
unimplemented!()
}
pub(crate) fn set_print_type(&mut self, print: bool) {
unimplemented!()
}
pub(crate) fn set_print_type_fun(&mut self, fun_name: &str) {
unimplemented!()
}
pub(crate) fn set_print(&mut self, print: bool) {
unimplemented!()
}
pub(crate) fn set_strict_aliasing(&mut self, strict: bool) {
unimplemented!()
}
pub(crate) fn set_loose_types(&mut self, loose: bool) {
unimplemented!()
}
pub(crate) fn set_rust_rules(&mut self, val: bool) {
unimplemented!()
}
}
}
impl TypeTree {
pub(crate) fn new() -> TypeTree {
let inner = unsafe { EnzymeNewTypeTree() };
let wrapper = EnzymeWrapper::get_instance();
let inner = wrapper.new_type_tree();
TypeTree { inner }
}
pub(crate) fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree {
let inner = unsafe { EnzymeNewTypeTreeCT(t, ctx) };
let wrapper = EnzymeWrapper::get_instance();
let inner = wrapper.new_type_tree_ct(t, ctx);
TypeTree { inner }
}
pub(crate) fn merge(self, other: Self) -> Self {
unsafe {
EnzymeMergeTypeTree(self.inner, other.inner);
}
let wrapper = EnzymeWrapper::get_instance();
wrapper.merge_type_tree(self.inner, other.inner);
drop(other);
self
}
@@ -316,37 +622,36 @@ pub(crate) fn shift(
add_offset: usize,
) -> Self {
let layout = std::ffi::CString::new(layout).unwrap();
unsafe {
EnzymeTypeTreeShiftIndiciesEq(
self.inner,
layout.as_ptr(),
offset as i64,
max_size as i64,
add_offset as u64,
);
}
let wrapper = EnzymeWrapper::get_instance();
wrapper.shift_indicies_eq(
self.inner,
layout.as_ptr(),
offset as i64,
max_size as i64,
add_offset as u64,
);
self
}
pub(crate) fn insert(&mut self, indices: &[i64], ct: CConcreteType, ctx: &Context) {
unsafe {
EnzymeTypeTreeInsertEq(self.inner, indices.as_ptr(), indices.len(), ct, ctx);
}
let wrapper = EnzymeWrapper::get_instance();
wrapper.tree_insert_eq(self.inner, indices.as_ptr(), indices.len(), ct, ctx);
}
}
impl Clone for TypeTree {
fn clone(&self) -> Self {
let inner = unsafe { EnzymeNewTypeTreeTR(self.inner) };
let wrapper = EnzymeWrapper::get_instance();
let inner = wrapper.new_type_tree_tr(self.inner);
TypeTree { inner }
}
}
impl std::fmt::Display for TypeTree {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let ptr = unsafe { EnzymeTypeTreeToString(self.inner) };
let wrapper = EnzymeWrapper::get_instance();
let ptr = wrapper.tree_to_string(self.inner);
let cstr = unsafe { std::ffi::CStr::from_ptr(ptr) };
match cstr.to_str() {
Ok(x) => write!(f, "{}", x)?,
@@ -354,9 +659,7 @@ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
}
// delete C string pointer
unsafe {
EnzymeTypeTreeToStringFree(ptr);
}
wrapper.tree_to_string_free(ptr);
Ok(())
}
@@ -370,6 +673,7 @@ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
impl Drop for TypeTree {
fn drop(&mut self) {
unsafe { EnzymeFreeTypeTree(self.inner) }
let wrapper = EnzymeWrapper::get_instance();
wrapper.free_type_tree(self.inner)
}
}
+1 -1
View File
@@ -2411,7 +2411,7 @@ pub(crate) fn LLVMRustOptimize<'a>(
LoopVectorize: bool,
DisableSimplifyLibCalls: bool,
EmitLifetimeMarkers: bool,
RunEnzyme: bool,
RunEnzyme: *const c_void,
PrintBeforeEnzyme: bool,
PrintAfterEnzyme: bool,
PrintPasses: bool,
+7 -4
View File
@@ -2,6 +2,7 @@
#[cfg(feature = "llvm_enzyme")]
use {
crate::attributes,
crate::llvm::EnzymeWrapper,
rustc_ast::expand::typetree::TypeTree as RustTypeTree,
std::ffi::{CString, c_char, c_uint},
};
@@ -77,7 +78,8 @@ pub(crate) fn add_tt<'ll>(
for (i, input) in inputs.iter().enumerate() {
unsafe {
let enzyme_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx);
let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner);
let enzyme_wrapper = EnzymeWrapper::get_instance();
let c_str = enzyme_wrapper.tree_to_string(enzyme_tt.inner);
let c_str = std::ffi::CStr::from_ptr(c_str);
let attr = llvm::LLVMCreateStringAttribute(
@@ -89,13 +91,14 @@ pub(crate) fn add_tt<'ll>(
);
attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]);
llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr());
enzyme_wrapper.tree_to_string_free(c_str.as_ptr());
}
}
unsafe {
let enzyme_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx);
let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner);
let enzyme_wrapper = EnzymeWrapper::get_instance();
let c_str = enzyme_wrapper.tree_to_string(enzyme_tt.inner);
let c_str = std::ffi::CStr::from_ptr(c_str);
let ret_attr = llvm::LLVMCreateStringAttribute(
@@ -107,7 +110,7 @@ pub(crate) fn add_tt<'ll>(
);
attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]);
llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr());
enzyme_wrapper.tree_to_string_free(c_str.as_ptr());
}
}
@@ -29,6 +29,7 @@
use rustc_session::Session;
use rustc_session::config::{
self, CrateType, Lto, OutFileName, OutputFilenames, OutputType, Passes, SwitchWithOptPath,
Sysroot,
};
use rustc_span::source_map::SourceMap;
use rustc_span::{FileName, InnerSpan, Span, SpanData, sym};
@@ -346,6 +347,7 @@ pub struct CodegenContext<B: WriteBackendMethods> {
pub split_debuginfo: rustc_target::spec::SplitDebuginfo,
pub split_dwarf_kind: rustc_session::config::SplitDwarfKind,
pub pointer_size: Size,
pub sysroot: Sysroot,
/// Emitter to use for diagnostics produced during codegen.
pub diag_emitter: SharedEmitter,
@@ -1317,6 +1319,7 @@ fn start_executing_work<B: ExtraBackendMethods>(
parallel: backend.supports_parallel() && !sess.opts.unstable_opts.no_parallel_backend,
pointer_size: tcx.data_layout.pointer_size(),
invocation_temp: sess.invocation_temp.clone(),
sysroot: sess.opts.sysroot.clone(),
};
// This is the "main loop" of parallel work happening for parallel codegen.
@@ -550,17 +550,8 @@ struct LLVMRustSanitizerOptions {
bool SanitizeKernelAddressRecover;
};
// This symbol won't be available or used when Enzyme is not enabled.
// Always set AugmentPassBuilder to true, since it registers optimizations which
// will improve the performance for Enzyme.
#ifdef ENZYME
extern "C" void registerEnzymeAndPassPipeline(llvm::PassBuilder &PB,
/* augmentPassBuilder */ bool);
extern "C" {
extern llvm::cl::opt<std::string> EnzymeFunctionToAnalyze;
}
#endif
extern "C" typedef void (*registerEnzymeAndPassPipelineFn)(
llvm::PassBuilder &PB, bool augment);
extern "C" LLVMRustResult LLVMRustOptimize(
LLVMModuleRef ModuleRef, LLVMTargetMachineRef TMRef,
@@ -569,8 +560,8 @@ extern "C" LLVMRustResult LLVMRustOptimize(
bool LintIR, LLVMRustThinLTOBuffer **ThinLTOBufferRef, bool EmitThinLTO,
bool EmitThinLTOSummary, bool MergeFunctions, bool UnrollLoops,
bool SLPVectorize, bool LoopVectorize, bool DisableSimplifyLibCalls,
bool EmitLifetimeMarkers, bool RunEnzyme, bool PrintBeforeEnzyme,
bool PrintAfterEnzyme, bool PrintPasses,
bool EmitLifetimeMarkers, registerEnzymeAndPassPipelineFn EnzymePtr,
bool PrintBeforeEnzyme, bool PrintAfterEnzyme, bool PrintPasses,
LLVMRustSanitizerOptions *SanitizerOptions, const char *PGOGenPath,
const char *PGOUsePath, bool InstrumentCoverage,
const char *InstrProfileOutput, const char *PGOSampleUsePath,
@@ -907,8 +898,8 @@ extern "C" LLVMRustResult LLVMRustOptimize(
}
// now load "-enzyme" pass:
#ifdef ENZYME
if (RunEnzyme) {
// With dlopen, ENZYME macro may not be defined, so check EnzymePtr directly
if (EnzymePtr) {
if (PrintBeforeEnzyme) {
// Handle the Rust flag `-Zautodiff=PrintModBefore`.
@@ -916,29 +907,19 @@ extern "C" LLVMRustResult LLVMRustOptimize(
MPM.addPass(PrintModulePass(outs(), Banner, true, false));
}
registerEnzymeAndPassPipeline(PB, false);
EnzymePtr(PB, false);
if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) {
std::string ErrMsg = toString(std::move(Err));
LLVMRustSetLastError(ErrMsg.c_str());
return LLVMRustResult::Failure;
}
// Check if PrintTAFn was used and add type analysis pass if needed
if (!EnzymeFunctionToAnalyze.empty()) {
if (auto Err = PB.parsePassPipeline(MPM, "print-type-analysis")) {
std::string ErrMsg = toString(std::move(Err));
LLVMRustSetLastError(ErrMsg.c_str());
return LLVMRustResult::Failure;
}
}
if (PrintAfterEnzyme) {
// Handle the Rust flag `-Zautodiff=PrintModAfter`.
std::string Banner = "Module after EnzymeNewPM";
MPM.addPass(PrintModulePass(outs(), Banner, true, false));
}
}
#endif
if (PrintPasses) {
// Print all passes from the PM:
std::string Pipeline;
@@ -1791,18 +1791,6 @@ extern "C" void LLVMRustSetNoSanitizeHWAddress(LLVMValueRef Global) {
GV.setSanitizerMetadata(MD);
}
#ifdef ENZYME
extern "C" {
extern llvm::cl::opt<unsigned> EnzymeMaxTypeDepth;
}
extern "C" size_t LLVMRustEnzymeGetMaxTypeDepth() { return EnzymeMaxTypeDepth; }
#else
extern "C" size_t LLVMRustEnzymeGetMaxTypeDepth() {
return 6; // Default fallback depth
}
#endif
// Statically assert that the fixed metadata kind IDs declared in
// `metadata_kind.rs` match the ones actually used by LLVM.
#define FIXED_MD_KIND(VARIANT, VALUE) \
@@ -1232,19 +1232,6 @@ pub fn rustc_cargo(
// <https://rust-lang.zulipchat.com/#narrow/stream/131828-t-compiler/topic/Internal.20lint.20for.20raw.20.60print!.60.20and.20.60println!.60.3F>.
cargo.rustflag("-Zon-broken-pipe=kill");
// We want to link against registerEnzyme and in the future we want to use additional
// functionality from Enzyme core. For that we need to link against Enzyme.
if builder.config.llvm_enzyme {
let arch = builder.build.host_target;
let enzyme_dir = builder.build.out.join(arch).join("enzyme").join("lib");
cargo.rustflag("-L").rustflag(enzyme_dir.to_str().expect("Invalid path"));
if let Some(llvm_config) = builder.llvm_config(builder.config.host_target) {
let llvm_version_major = llvm::get_llvm_version_major(builder, &llvm_config);
cargo.rustflag("-l").rustflag(&format!("Enzyme-{llvm_version_major}"));
}
}
// Building with protected visibility reduces the number of dynamic relocations needed, giving
// us a faster startup time. However GNU ld < 2.40 will error if we try to link a shared object
// with direct references to protected symbols, so for now we only use protected symbols if
+2
View File
@@ -50,6 +50,8 @@ unstalled = "unstalled"
debug_aranges = "debug_aranges"
DNS_ERROR_INVAILD_VIRTUALIZATION_INSTANCE_NAME = "DNS_ERROR_INVAILD_VIRTUALIZATION_INSTANCE_NAME"
EnzymeTypeTreeShiftIndiciesEq = "EnzymeTypeTreeShiftIndiciesEq"
EnzymeTypeTreeShiftIndiciesEqFn = "EnzymeTypeTreeShiftIndiciesEqFn"
shift_indicies_eq = "shift_indicies_eq"
ERRNO_ACCES = "ERRNO_ACCES"
ERROR_DS_FILTER_USES_CONTRUCTED_ATTRS = "ERROR_DS_FILTER_USES_CONTRUCTED_ATTRS"
ERROR_DS_NOT_AUTHORITIVE_FOR_DST_NC = "ERROR_DS_NOT_AUTHORITIVE_FOR_DST_NC"