mirror of
https://github.com/rust-lang/rust.git
synced 2026-04-27 18:57:42 +03:00
Auto merge of #153379 - TKanX:refactor/149164-simplify-autodiff-rlib, r=ZuseZ4
refactor(autodiff): Simplify Autodiff Handling of `rlib` Dependencies ### Summary: Resolves the two FIXMEs left in rust-lang/rust#149033, per @bjorn3 guidance in [the discussion](https://github.com/rust-lang/rust/pull/149033#discussion_r2535465880). Closes rust-lang/rust#149164 r? @ZuseZ4 cc @bjorn3
This commit is contained in:
@@ -6,7 +6,7 @@
|
||||
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
|
||||
use rustc_data_structures::thin_vec::ThinVec;
|
||||
use rustc_hir::attrs::RustcAutodiff;
|
||||
use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv};
|
||||
use rustc_middle::ty::{PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
|
||||
use rustc_middle::{bug, ty};
|
||||
use rustc_target::callconv::PassMode;
|
||||
use tracing::debug;
|
||||
@@ -18,25 +18,23 @@
|
||||
|
||||
pub(crate) fn adjust_activity_to_abi<'tcx>(
|
||||
tcx: TyCtxt<'tcx>,
|
||||
instance: Instance<'tcx>,
|
||||
fn_ptr_ty: Ty<'tcx>,
|
||||
typing_env: TypingEnv<'tcx>,
|
||||
da: &mut ThinVec<DiffActivity>,
|
||||
) {
|
||||
let fn_ty = instance.ty(tcx, typing_env);
|
||||
|
||||
if !matches!(fn_ty.kind(), ty::FnDef(..)) {
|
||||
bug!("expected fn def for autodiff, got {:?}", fn_ty);
|
||||
if !matches!(fn_ptr_ty.kind(), ty::FnPtr(..)) {
|
||||
bug!("expected fn ptr for autodiff, got {:?}", fn_ptr_ty);
|
||||
}
|
||||
|
||||
// We don't actually pass the types back into the type system.
|
||||
// All we do is decide how to handle the arguments.
|
||||
let sig = fn_ty.fn_sig(tcx).skip_binder();
|
||||
let fn_sig = fn_ptr_ty.fn_sig(tcx);
|
||||
let sig = fn_sig.skip_binder();
|
||||
|
||||
// FIXME(Sa4dUs): pass proper varargs once we have support for differentiating variadic functions
|
||||
let Ok(fn_abi) =
|
||||
tcx.fn_abi_of_instance(typing_env.as_query_input((instance, ty::List::empty())))
|
||||
let Ok(fn_abi) = tcx.fn_abi_of_fn_ptr(typing_env.as_query_input((fn_sig, ty::List::empty())))
|
||||
else {
|
||||
bug!("failed to get fn_abi of instance with empty varargs");
|
||||
bug!("failed to get fn_abi of fn_ptr with empty varargs");
|
||||
};
|
||||
|
||||
let mut new_activities = vec![];
|
||||
|
||||
@@ -1322,29 +1322,8 @@ fn codegen_autodiff<'ll, 'tcx>(
|
||||
let ret_ty = sig.output();
|
||||
let llret_ty = bx.layout_of(ret_ty).llvm_type(bx);
|
||||
|
||||
// Get source, diff, and attrs
|
||||
let (source_id, source_args) = match fn_args.into_type_list(tcx)[0].kind() {
|
||||
ty::FnDef(def_id, source_params) => (def_id, source_params),
|
||||
_ => bug!("invalid autodiff intrinsic args"),
|
||||
};
|
||||
|
||||
let fn_source = match Instance::try_resolve(tcx, bx.cx.typing_env(), *source_id, source_args) {
|
||||
Ok(Some(instance)) => instance,
|
||||
Ok(None) => bug!(
|
||||
"could not resolve ({:?}, {:?}) to a specific autodiff instance",
|
||||
source_id,
|
||||
source_args
|
||||
),
|
||||
Err(_) => {
|
||||
// An error has already been emitted
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let source_symbol = symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE);
|
||||
let Some(fn_to_diff) = bx.cx.get_function(&source_symbol) else {
|
||||
bug!("could not find source function")
|
||||
};
|
||||
let source_fn_ptr_ty = fn_args.into_type_list(tcx)[0];
|
||||
let fn_to_diff = args[0].immediate();
|
||||
|
||||
let (diff_id, diff_args) = match fn_args.into_type_list(tcx)[1].kind() {
|
||||
ty::FnDef(def_id, diff_args) => (def_id, diff_args),
|
||||
@@ -1375,13 +1354,12 @@ fn codegen_autodiff<'ll, 'tcx>(
|
||||
|
||||
adjust_activity_to_abi(
|
||||
tcx,
|
||||
fn_source,
|
||||
source_fn_ptr_ty,
|
||||
TypingEnv::fully_monomorphized(),
|
||||
&mut diff_attrs.input_activity,
|
||||
);
|
||||
|
||||
let fnc_tree =
|
||||
rustc_middle::ty::fnc_typetrees(tcx, fn_source.ty(tcx, TypingEnv::fully_monomorphized()));
|
||||
let fnc_tree = rustc_middle::ty::fnc_typetrees(tcx, source_fn_ptr_ty);
|
||||
|
||||
// Build body
|
||||
generate_enzyme_call(
|
||||
|
||||
Reference in New Issue
Block a user