Auto merge of #153379 - TKanX:refactor/149164-simplify-autodiff-rlib, r=ZuseZ4

refactor(autodiff): Simplify Autodiff Handling of `rlib` Dependencies

### Summary:

Resolves the two FIXMEs left in rust-lang/rust#149033, per @bjorn3 guidance in [the discussion](https://github.com/rust-lang/rust/pull/149033#discussion_r2535465880).

Closes rust-lang/rust#149164 

r? @ZuseZ4
cc @bjorn3
This commit is contained in:
bors
2026-03-11 02:03:25 +00:00
9 changed files with 101 additions and 122 deletions
@@ -6,7 +6,7 @@
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_data_structures::thin_vec::ThinVec;
use rustc_hir::attrs::RustcAutodiff;
use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv};
use rustc_middle::ty::{PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
use rustc_middle::{bug, ty};
use rustc_target::callconv::PassMode;
use tracing::debug;
@@ -18,25 +18,23 @@
pub(crate) fn adjust_activity_to_abi<'tcx>(
tcx: TyCtxt<'tcx>,
instance: Instance<'tcx>,
fn_ptr_ty: Ty<'tcx>,
typing_env: TypingEnv<'tcx>,
da: &mut ThinVec<DiffActivity>,
) {
let fn_ty = instance.ty(tcx, typing_env);
if !matches!(fn_ty.kind(), ty::FnDef(..)) {
bug!("expected fn def for autodiff, got {:?}", fn_ty);
if !matches!(fn_ptr_ty.kind(), ty::FnPtr(..)) {
bug!("expected fn ptr for autodiff, got {:?}", fn_ptr_ty);
}
// We don't actually pass the types back into the type system.
// All we do is decide how to handle the arguments.
let sig = fn_ty.fn_sig(tcx).skip_binder();
let fn_sig = fn_ptr_ty.fn_sig(tcx);
let sig = fn_sig.skip_binder();
// FIXME(Sa4dUs): pass proper varargs once we have support for differentiating variadic functions
let Ok(fn_abi) =
tcx.fn_abi_of_instance(typing_env.as_query_input((instance, ty::List::empty())))
let Ok(fn_abi) = tcx.fn_abi_of_fn_ptr(typing_env.as_query_input((fn_sig, ty::List::empty())))
else {
bug!("failed to get fn_abi of instance with empty varargs");
bug!("failed to get fn_abi of fn_ptr with empty varargs");
};
let mut new_activities = vec![];
+4 -26
View File
@@ -1322,29 +1322,8 @@ fn codegen_autodiff<'ll, 'tcx>(
let ret_ty = sig.output();
let llret_ty = bx.layout_of(ret_ty).llvm_type(bx);
// Get source, diff, and attrs
let (source_id, source_args) = match fn_args.into_type_list(tcx)[0].kind() {
ty::FnDef(def_id, source_params) => (def_id, source_params),
_ => bug!("invalid autodiff intrinsic args"),
};
let fn_source = match Instance::try_resolve(tcx, bx.cx.typing_env(), *source_id, source_args) {
Ok(Some(instance)) => instance,
Ok(None) => bug!(
"could not resolve ({:?}, {:?}) to a specific autodiff instance",
source_id,
source_args
),
Err(_) => {
// An error has already been emitted
return;
}
};
let source_symbol = symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE);
let Some(fn_to_diff) = bx.cx.get_function(&source_symbol) else {
bug!("could not find source function")
};
let source_fn_ptr_ty = fn_args.into_type_list(tcx)[0];
let fn_to_diff = args[0].immediate();
let (diff_id, diff_args) = match fn_args.into_type_list(tcx)[1].kind() {
ty::FnDef(def_id, diff_args) => (def_id, diff_args),
@@ -1375,13 +1354,12 @@ fn codegen_autodiff<'ll, 'tcx>(
adjust_activity_to_abi(
tcx,
fn_source,
source_fn_ptr_ty,
TypingEnv::fully_monomorphized(),
&mut diff_attrs.input_activity,
);
let fnc_tree =
rustc_middle::ty::fnc_typetrees(tcx, fn_source.ty(tcx, TypingEnv::fully_monomorphized()));
let fnc_tree = rustc_middle::ty::fnc_typetrees(tcx, source_fn_ptr_ty);
// Build body
generate_enzyme_call(