Auto merge of #153379 - TKanX:refactor/149164-simplify-autodiff-rlib, r=ZuseZ4

refactor(autodiff): Simplify Autodiff Handling of `rlib` Dependencies

### Summary:

Resolves the two FIXMEs left in rust-lang/rust#149033, per @bjorn3 guidance in [the discussion](https://github.com/rust-lang/rust/pull/149033#discussion_r2535465880).

Closes rust-lang/rust#149164 

r? @ZuseZ4
cc @bjorn3
This commit is contained in:
bors
2026-03-11 02:03:25 +00:00
9 changed files with 101 additions and 122 deletions
+49 -4
View File
@@ -20,7 +20,7 @@ mod llvm_enzyme {
};
use rustc_expand::base::{Annotatable, ExtCtxt};
use rustc_hir::attrs::RustcAutodiff;
use rustc_span::{Ident, Span, Symbol, sym};
use rustc_span::{Ident, Span, Symbol, kw, sym};
use thin_vec::{ThinVec, thin_vec};
use tracing::{debug, trace};
@@ -197,7 +197,7 @@ pub(crate) fn expand_reverse(
/// }
/// #[rustc_autodiff(Reverse, Duplicated, Active)]
/// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 {
/// std::intrinsics::autodiff(sin::<>, cos_box::<>, (x, dx, dret))
/// std::intrinsics::autodiff(sin::<> as fn(..) -> .., cos_box::<>, (x, dx, dret))
/// }
/// ```
/// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
@@ -326,6 +326,7 @@ pub(crate) fn expand_with_mode(
primal,
first_ident(&meta_item_vec[0]),
span,
&sig,
&d_sig,
&generics,
is_impl,
@@ -496,18 +497,62 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
// Generate `autodiff` intrinsic call
// ```
// std::intrinsics::autodiff(source, diff, (args))
// std::intrinsics::autodiff(source as fn(..) -> .., diff, (args))
// ```
fn call_autodiff(
ecx: &ExtCtxt<'_>,
primal: Ident,
diff: Ident,
span: Span,
p_sig: &FnSig,
d_sig: &FnSig,
generics: &Generics,
is_impl: bool,
) -> rustc_ast::Stmt {
let primal_path_expr = gen_turbofish_expr(ecx, primal, generics, span, is_impl);
let self_ty = || ecx.ty_path(ast::Path::from_ident(Ident::with_dummy_span(kw::SelfUpper)));
let fn_ptr_params: ThinVec<ast::Param> = p_sig
.decl
.inputs
.iter()
.map(|param| {
let ty = match &param.ty.kind {
TyKind::ImplicitSelf => self_ty(),
TyKind::Ref(lt, mt) if matches!(mt.ty.kind, TyKind::ImplicitSelf) => ecx.ty(
span,
TyKind::Ref(lt.clone(), ast::MutTy { ty: self_ty(), mutbl: mt.mutbl }),
),
TyKind::Ptr(mt) if matches!(mt.ty.kind, TyKind::ImplicitSelf) => {
ecx.ty(span, TyKind::Ptr(ast::MutTy { ty: self_ty(), mutbl: mt.mutbl }))
}
_ => param.ty.clone(),
};
ast::Param {
attrs: ast::AttrVec::new(),
ty,
pat: Box::new(ecx.pat_wild(span)),
id: ast::DUMMY_NODE_ID,
span,
is_placeholder: false,
}
})
.collect();
let fn_ptr_ty = ecx.ty(
span,
TyKind::FnPtr(Box::new(ast::FnPtrTy {
safety: p_sig.header.safety,
ext: p_sig.header.ext,
generic_params: ThinVec::new(),
decl: Box::new(ast::FnDecl {
inputs: fn_ptr_params,
output: p_sig.decl.output.clone(),
}),
decl_span: span,
})),
);
let primal_fn_ptr = ecx.expr(span, ast::ExprKind::Cast(primal_path_expr, fn_ptr_ty));
let diff_path_expr = gen_turbofish_expr(ecx, diff, generics, span, is_impl);
let tuple_expr = ecx.expr_tuple(
@@ -529,7 +574,7 @@ fn call_autodiff(
let call_expr = ecx.expr_call(
span,
ecx.expr_path(enzyme_path),
vec![primal_path_expr, diff_path_expr, tuple_expr].into(),
vec![primal_fn_ptr, diff_path_expr, tuple_expr].into(),
);
ecx.stmt_expr(call_expr)
@@ -6,7 +6,7 @@
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_data_structures::thin_vec::ThinVec;
use rustc_hir::attrs::RustcAutodiff;
use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv};
use rustc_middle::ty::{PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
use rustc_middle::{bug, ty};
use rustc_target::callconv::PassMode;
use tracing::debug;
@@ -18,25 +18,23 @@
pub(crate) fn adjust_activity_to_abi<'tcx>(
tcx: TyCtxt<'tcx>,
instance: Instance<'tcx>,
fn_ptr_ty: Ty<'tcx>,
typing_env: TypingEnv<'tcx>,
da: &mut ThinVec<DiffActivity>,
) {
let fn_ty = instance.ty(tcx, typing_env);
if !matches!(fn_ty.kind(), ty::FnDef(..)) {
bug!("expected fn def for autodiff, got {:?}", fn_ty);
if !matches!(fn_ptr_ty.kind(), ty::FnPtr(..)) {
bug!("expected fn ptr for autodiff, got {:?}", fn_ptr_ty);
}
// We don't actually pass the types back into the type system.
// All we do is decide how to handle the arguments.
let sig = fn_ty.fn_sig(tcx).skip_binder();
let fn_sig = fn_ptr_ty.fn_sig(tcx);
let sig = fn_sig.skip_binder();
// FIXME(Sa4dUs): pass proper varargs once we have support for differentiating variadic functions
let Ok(fn_abi) =
tcx.fn_abi_of_instance(typing_env.as_query_input((instance, ty::List::empty())))
let Ok(fn_abi) = tcx.fn_abi_of_fn_ptr(typing_env.as_query_input((fn_sig, ty::List::empty())))
else {
bug!("failed to get fn_abi of instance with empty varargs");
bug!("failed to get fn_abi of fn_ptr with empty varargs");
};
let mut new_activities = vec![];
+4 -26
View File
@@ -1322,29 +1322,8 @@ fn codegen_autodiff<'ll, 'tcx>(
let ret_ty = sig.output();
let llret_ty = bx.layout_of(ret_ty).llvm_type(bx);
// Get source, diff, and attrs
let (source_id, source_args) = match fn_args.into_type_list(tcx)[0].kind() {
ty::FnDef(def_id, source_params) => (def_id, source_params),
_ => bug!("invalid autodiff intrinsic args"),
};
let fn_source = match Instance::try_resolve(tcx, bx.cx.typing_env(), *source_id, source_args) {
Ok(Some(instance)) => instance,
Ok(None) => bug!(
"could not resolve ({:?}, {:?}) to a specific autodiff instance",
source_id,
source_args
),
Err(_) => {
// An error has already been emitted
return;
}
};
let source_symbol = symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE);
let Some(fn_to_diff) = bx.cx.get_function(&source_symbol) else {
bug!("could not find source function")
};
let source_fn_ptr_ty = fn_args.into_type_list(tcx)[0];
let fn_to_diff = args[0].immediate();
let (diff_id, diff_args) = match fn_args.into_type_list(tcx)[1].kind() {
ty::FnDef(def_id, diff_args) => (def_id, diff_args),
@@ -1375,13 +1354,12 @@ fn codegen_autodiff<'ll, 'tcx>(
adjust_activity_to_abi(
tcx,
fn_source,
source_fn_ptr_ty,
TypingEnv::fully_monomorphized(),
&mut diff_attrs.input_activity,
);
let fnc_tree =
rustc_middle::ty::fnc_typetrees(tcx, fn_source.ty(tcx, TypingEnv::fully_monomorphized()));
let fnc_tree = rustc_middle::ty::fnc_typetrees(tcx, source_fn_ptr_ty);
// Build body
generate_enzyme_call(
@@ -35,11 +35,6 @@ fn cross_crate_inlinable(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool {
return true;
}
// FIXME(autodiff): replace this as per discussion in https://github.com/rust-lang/rust/pull/149033#discussion_r2535465880
if find_attr!(tcx, def_id, RustcAutodiff(..)) {
return true;
}
if find_attr!(tcx, def_id, RustcIntrinsic) {
// Intrinsic fallback bodies are always cross-crate inlineable.
// To ensure that the MIR inliner doesn't cluelessly try to inline fallback
@@ -205,8 +205,6 @@
//! this is not implemented however: a mono item will be produced
//! regardless of whether it is actually needed or not.
mod autodiff;
use std::cell::OnceCell;
use std::ops::ControlFlow;
@@ -240,7 +238,6 @@
use rustc_span::{DUMMY_SP, Span};
use tracing::{debug, instrument, trace};
use crate::collector::autodiff::collect_autodiff_fn;
use crate::errors::{
self, EncounteredErrorWhileInstantiating, EncounteredErrorWhileInstantiatingGlobalAsm,
NoOptimizedMir, RecursionLimit,
@@ -990,8 +987,6 @@ fn visit_instance_use<'tcx>(
return;
}
if let Some(intrinsic) = tcx.intrinsic(instance.def_id()) {
collect_autodiff_fn(tcx, instance, intrinsic, output);
if let Some(_requirement) = ValidityRequirement::from_intrinsic(intrinsic.name) {
// The intrinsics assert_inhabited, assert_zero_valid, and assert_mem_uninitialized_valid will
// be lowered in codegen to nothing or a call to panic_nounwind. So if we encounter any
@@ -1,50 +0,0 @@
use rustc_middle::bug;
use rustc_middle::ty::{self, GenericArg, IntrinsicDef, TyCtxt};
use crate::collector::{MonoItems, create_fn_mono_item};
// Here, we force both primal and diff function to be collected in
// mono so this does not interfere in `autodiff` intrinsics
// codegen process. If they are unused, LLVM will remove them when
// compiling with O3.
// FIXME(autodiff): Remove this whole file, as per discussion in
// https://github.com/rust-lang/rust/pull/149033#discussion_r2535465880
pub(crate) fn collect_autodiff_fn<'tcx>(
tcx: TyCtxt<'tcx>,
instance: ty::Instance<'tcx>,
intrinsic: IntrinsicDef,
output: &mut MonoItems<'tcx>,
) {
if intrinsic.name != rustc_span::sym::autodiff {
return;
};
collect_autodiff_fn_from_arg(instance.args[0], tcx, output);
}
fn collect_autodiff_fn_from_arg<'tcx>(
arg: GenericArg<'tcx>,
tcx: TyCtxt<'tcx>,
output: &mut MonoItems<'tcx>,
) {
let (instance, span) = match arg.kind() {
ty::GenericArgKind::Type(ty) => match *ty.kind() {
ty::FnDef(def_id, substs) => {
let span = tcx.def_span(def_id);
let instance = ty::Instance::expect_resolve(
tcx,
ty::TypingEnv::non_body_analysis(tcx, def_id),
def_id,
substs,
span,
);
(instance, span)
}
_ => bug!("expected autodiff function"),
},
_ => bug!("expected type when matching autodiff arg"),
};
output.push(create_fn_mono_item(tcx, instance, span));
}
+28 -15
View File
@@ -35,7 +35,8 @@ pub fn f1(x: &[f64], y: f64) -> f64 {
}
#[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) {
::core::intrinsics::autodiff(f1::<>, df1::<>, (x, bx_0, y))
::core::intrinsics::autodiff(f1::<> as fn(_: &[f64], _: f64) -> f64,
df1::<>, (x, bx_0, y))
}
#[rustc_autodiff]
pub fn f2(x: &[f64], y: f64) -> f64 {
@@ -43,7 +44,8 @@ pub fn f2(x: &[f64], y: f64) -> f64 {
}
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
::core::intrinsics::autodiff(f2::<>, df2::<>, (x, bx_0, y))
::core::intrinsics::autodiff(f2::<> as fn(_: &[f64], _: f64) -> f64,
df2::<>, (x, bx_0, y))
}
#[rustc_autodiff]
pub fn f3(x: &[f64], y: f64) -> f64 {
@@ -51,27 +53,33 @@ pub fn f3(x: &[f64], y: f64) -> f64 {
}
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
::core::intrinsics::autodiff(f3::<>, df3::<>, (x, bx_0, y))
::core::intrinsics::autodiff(f3::<> as fn(_: &[f64], _: f64) -> f64,
df3::<>, (x, bx_0, y))
}
#[rustc_autodiff]
pub fn f4() {}
#[rustc_autodiff(Forward, 1, None)]
pub fn df4() -> () { ::core::intrinsics::autodiff(f4::<>, df4::<>, ()) }
pub fn df4() -> () {
::core::intrinsics::autodiff(f4::<> as fn(), df4::<>, ())
}
#[rustc_autodiff]
pub fn f5(x: &[f64], y: f64) -> f64 {
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Forward, 1, Const, Dual, Const)]
pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64 {
::core::intrinsics::autodiff(f5::<>, df5_y::<>, (x, y, by_0))
::core::intrinsics::autodiff(f5::<> as fn(_: &[f64], _: f64) -> f64,
df5_y::<>, (x, y, by_0))
}
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
::core::intrinsics::autodiff(f5::<>, df5_x::<>, (x, bx_0, y))
::core::intrinsics::autodiff(f5::<> as fn(_: &[f64], _: f64) -> f64,
df5_x::<>, (x, bx_0, y))
}
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
::core::intrinsics::autodiff(f5::<>, df5_rev::<>, (x, dx_0, y, dret))
::core::intrinsics::autodiff(f5::<> as fn(_: &[f64], _: f64) -> f64,
df5_rev::<>, (x, dx_0, y, dret))
}
struct DoesNotImplDefault;
#[rustc_autodiff]
@@ -80,13 +88,14 @@ pub fn f6() -> DoesNotImplDefault {
}
#[rustc_autodiff(Forward, 1, Const)]
pub fn df6() -> DoesNotImplDefault {
::core::intrinsics::autodiff(f6::<>, df6::<>, ())
::core::intrinsics::autodiff(f6::<> as fn() -> DoesNotImplDefault,
df6::<>, ())
}
#[rustc_autodiff]
pub fn f7(x: f32) -> () {}
#[rustc_autodiff(Forward, 1, Const, None)]
pub fn df7(x: f32) -> () {
::core::intrinsics::autodiff(f7::<>, df7::<>, (x,))
::core::intrinsics::autodiff(f7::<> as fn(_: f32) -> (), df7::<>, (x,))
}
#[no_mangle]
#[rustc_autodiff]
@@ -94,29 +103,32 @@ fn f8(x: &f32) -> f32 { ::core::panicking::panic("not implemented") }
#[rustc_autodiff(Forward, 4, Dual, Dual)]
fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
-> [f32; 5usize] {
::core::intrinsics::autodiff(f8::<>, f8_3::<>,
::core::intrinsics::autodiff(f8::<> as fn(_: &f32) -> f32, f8_3::<>,
(x, bx_0, bx_1, bx_2, bx_3))
}
#[rustc_autodiff(Forward, 4, Dual, DualOnly)]
fn f8_2(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
-> [f32; 4usize] {
::core::intrinsics::autodiff(f8::<>, f8_2::<>,
::core::intrinsics::autodiff(f8::<> as fn(_: &f32) -> f32, f8_2::<>,
(x, bx_0, bx_1, bx_2, bx_3))
}
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
fn f8_1(x: &f32, bx_0: &f32) -> f32 {
::core::intrinsics::autodiff(f8::<>, f8_1::<>, (x, bx_0))
::core::intrinsics::autodiff(f8::<> as fn(_: &f32) -> f32, f8_1::<>,
(x, bx_0))
}
pub fn f9() {
#[rustc_autodiff]
fn inner(x: f32) -> f32 { x * x }
#[rustc_autodiff(Forward, 1, Dual, Dual)]
fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) {
::core::intrinsics::autodiff(inner::<>, d_inner_2::<>, (x, bx_0))
::core::intrinsics::autodiff(inner::<> as fn(_: f32) -> f32,
d_inner_2::<>, (x, bx_0))
}
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
fn d_inner_1(x: f32, bx_0: f32) -> f32 {
::core::intrinsics::autodiff(inner::<>, d_inner_1::<>, (x, bx_0))
::core::intrinsics::autodiff(inner::<> as fn(_: f32) -> f32,
d_inner_1::<>, (x, bx_0))
}
}
#[rustc_autodiff]
@@ -124,6 +136,7 @@ pub fn f10<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T { *x * *x }
#[rustc_autodiff(Reverse, 1, Duplicated, Active)]
pub fn d_square<T: std::ops::Mul<Output = T> +
Copy>(x: &T, dx_0: &mut T, dret: T) -> T {
::core::intrinsics::autodiff(f10::<T>, d_square::<T>, (x, dx_0, dret))
::core::intrinsics::autodiff(f10::<T> as fn(_: &T) -> T, d_square::<T>,
(x, dx_0, dret))
}
fn main() {}
+10 -5
View File
@@ -28,32 +28,37 @@ pub fn f1(x: &[f64], y: f64) -> f64 {
}
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
::core::intrinsics::autodiff(f1::<>, df1::<>, (x, dx_0, y, dret))
::core::intrinsics::autodiff(f1::<> as fn(_: &[f64], _: f64) -> f64,
df1::<>, (x, dx_0, y, dret))
}
#[rustc_autodiff]
pub fn f2() {}
#[rustc_autodiff(Reverse, 1, None)]
pub fn df2() { ::core::intrinsics::autodiff(f2::<>, df2::<>, ()) }
pub fn df2() { ::core::intrinsics::autodiff(f2::<> as fn(), df2::<>, ()) }
#[rustc_autodiff]
pub fn f3(x: &[f64], y: f64) -> f64 {
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
::core::intrinsics::autodiff(f3::<>, df3::<>, (x, dx_0, y, dret))
::core::intrinsics::autodiff(f3::<> as fn(_: &[f64], _: f64) -> f64,
df3::<>, (x, dx_0, y, dret))
}
enum Foo { Reverse, }
use Foo::Reverse;
#[rustc_autodiff]
pub fn f4(x: f32) { ::core::panicking::panic("not implemented") }
#[rustc_autodiff(Reverse, 1, Const, None)]
pub fn df4(x: f32) { ::core::intrinsics::autodiff(f4::<>, df4::<>, (x,)) }
pub fn df4(x: f32) {
::core::intrinsics::autodiff(f4::<> as fn(_: f32), df4::<>, (x,))
}
#[rustc_autodiff]
pub fn f5(x: *const f32, y: &f32) {
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Reverse, 1, DuplicatedOnly, Duplicated, None)]
pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32) {
::core::intrinsics::autodiff(f5::<>, df5::<>, (x, dx_0, y, dy_0))
::core::intrinsics::autodiff(f5::<> as fn(_: *const f32, _: &f32),
df5::<>, (x, dx_0, y, dy_0))
}
fn main() {}
+2 -2
View File
@@ -30,7 +30,7 @@ impl MyTrait for Foo {
}
#[rustc_autodiff(Reverse, 1, Const, Active, Active)]
fn df(&self, x: f64, dret: f64) -> (f64, f64) {
::core::intrinsics::autodiff(Self::f::<>, Self::df::<>,
(self, x, dret))
::core::intrinsics::autodiff(Self::f::<> as
fn(_: &Self, _: f64) -> f64, Self::df::<>, (self, x, dret))
}
}