refactor(autodiff): Simplify rlib dep handling; use fn_ptr_ty in adjust_activity_to_abi, drop mono-collection & cross-crate-inline workarounds

This commit is contained in:
Tony Kan
2026-03-03 18:12:21 -08:00
parent d933cf483e
commit dd9922151f
5 changed files with 11 additions and 71 deletions
@@ -6,7 +6,7 @@
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_data_structures::thin_vec::ThinVec;
use rustc_hir::attrs::RustcAutodiff;
use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv};
use rustc_middle::ty::{PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
use rustc_middle::{bug, ty};
use rustc_target::callconv::PassMode;
use tracing::debug;
@@ -18,25 +18,23 @@
pub(crate) fn adjust_activity_to_abi<'tcx>(
tcx: TyCtxt<'tcx>,
instance: Instance<'tcx>,
fn_ptr_ty: Ty<'tcx>,
typing_env: TypingEnv<'tcx>,
da: &mut ThinVec<DiffActivity>,
) {
let fn_ty = instance.ty(tcx, typing_env);
if !matches!(fn_ty.kind(), ty::FnDef(..)) {
bug!("expected fn def for autodiff, got {:?}", fn_ty);
if !matches!(fn_ptr_ty.kind(), ty::FnPtr(..)) {
bug!("expected fn ptr for autodiff, got {:?}", fn_ptr_ty);
}
// We don't actually pass the types back into the type system.
// All we do is decide how to handle the arguments.
let sig = fn_ty.fn_sig(tcx).skip_binder();
let fn_sig = fn_ptr_ty.fn_sig(tcx);
let sig = fn_sig.skip_binder();
// FIXME(Sa4dUs): pass proper varargs once we have support for differentiating variadic functions
let Ok(fn_abi) =
tcx.fn_abi_of_instance(typing_env.as_query_input((instance, ty::List::empty())))
let Ok(fn_abi) = tcx.fn_abi_of_fn_ptr(typing_env.as_query_input((fn_sig, ty::List::empty())))
else {
bug!("failed to get fn_abi of instance with empty varargs");
bug!("failed to get fn_abi of fn_ptr with empty varargs");
};
let mut new_activities = vec![];
+3 -1
View File
@@ -1374,9 +1374,11 @@ fn codegen_autodiff<'ll, 'tcx>(
bug!("could not find autodiff attrs")
};
let fn_ptr_ty =
Ty::new_fn_ptr(tcx, fn_source.ty(tcx, TypingEnv::fully_monomorphized()).fn_sig(tcx));
adjust_activity_to_abi(
tcx,
fn_source,
fn_ptr_ty,
TypingEnv::fully_monomorphized(),
&mut diff_attrs.input_activity,
);
@@ -35,11 +35,6 @@ fn cross_crate_inlinable(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool {
return true;
}
// FIXME(autodiff): replace this as per discussion in https://github.com/rust-lang/rust/pull/149033#discussion_r2535465880
if find_attr!(tcx, def_id, RustcAutodiff(..)) {
return true;
}
if find_attr!(tcx, def_id, RustcIntrinsic) {
// Intrinsic fallback bodies are always cross-crate inlineable.
// To ensure that the MIR inliner doesn't cluelessly try to inline fallback
@@ -205,8 +205,6 @@
//! this is not implemented however: a mono item will be produced
//! regardless of whether it is actually needed or not.
mod autodiff;
use std::cell::OnceCell;
use std::ops::ControlFlow;
@@ -240,7 +238,6 @@
use rustc_span::{DUMMY_SP, Span};
use tracing::{debug, instrument, trace};
use crate::collector::autodiff::collect_autodiff_fn;
use crate::errors::{
self, EncounteredErrorWhileInstantiating, EncounteredErrorWhileInstantiatingGlobalAsm,
NoOptimizedMir, RecursionLimit,
@@ -990,8 +987,6 @@ fn visit_instance_use<'tcx>(
return;
}
if let Some(intrinsic) = tcx.intrinsic(instance.def_id()) {
collect_autodiff_fn(tcx, instance, intrinsic, output);
if let Some(_requirement) = ValidityRequirement::from_intrinsic(intrinsic.name) {
// The intrinsics assert_inhabited, assert_zero_valid, and assert_mem_uninitialized_valid will
// be lowered in codegen to nothing or a call to panic_nounwind. So if we encounter any
@@ -1,50 +0,0 @@
use rustc_middle::bug;
use rustc_middle::ty::{self, GenericArg, IntrinsicDef, TyCtxt};
use crate::collector::{MonoItems, create_fn_mono_item};
// Here, we force both primal and diff function to be collected in
// mono so this does not interfere in `autodiff` intrinsics
// codegen process. If they are unused, LLVM will remove them when
// compiling with O3.
// FIXME(autodiff): Remove this whole file, as per discussion in
// https://github.com/rust-lang/rust/pull/149033#discussion_r2535465880
pub(crate) fn collect_autodiff_fn<'tcx>(
tcx: TyCtxt<'tcx>,
instance: ty::Instance<'tcx>,
intrinsic: IntrinsicDef,
output: &mut MonoItems<'tcx>,
) {
if intrinsic.name != rustc_span::sym::autodiff {
return;
};
collect_autodiff_fn_from_arg(instance.args[0], tcx, output);
}
fn collect_autodiff_fn_from_arg<'tcx>(
arg: GenericArg<'tcx>,
tcx: TyCtxt<'tcx>,
output: &mut MonoItems<'tcx>,
) {
let (instance, span) = match arg.kind() {
ty::GenericArgKind::Type(ty) => match *ty.kind() {
ty::FnDef(def_id, substs) => {
let span = tcx.def_span(def_id);
let instance = ty::Instance::expect_resolve(
tcx,
ty::TypingEnv::non_body_analysis(tcx, def_id),
def_id,
substs,
span,
);
(instance, span)
}
_ => bug!("expected autodiff function"),
},
_ => bug!("expected type when matching autodiff arg"),
};
output.push(create_fn_mono_item(tcx, instance, span));
}