diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 259ee25407a5..30391e74480f 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -20,7 +20,7 @@ mod llvm_enzyme { }; use rustc_expand::base::{Annotatable, ExtCtxt}; use rustc_hir::attrs::RustcAutodiff; - use rustc_span::{Ident, Span, Symbol, sym}; + use rustc_span::{Ident, Span, Symbol, kw, sym}; use thin_vec::{ThinVec, thin_vec}; use tracing::{debug, trace}; @@ -197,7 +197,7 @@ pub(crate) fn expand_reverse( /// } /// #[rustc_autodiff(Reverse, Duplicated, Active)] /// fn cos_box(x: &Box, dx: &mut Box, dret: f32) -> f32 { - /// std::intrinsics::autodiff(sin::<>, cos_box::<>, (x, dx, dret)) + /// std::intrinsics::autodiff(sin::<> as fn(..) -> .., cos_box::<>, (x, dx, dret)) /// } /// ``` /// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked @@ -326,6 +326,7 @@ pub(crate) fn expand_with_mode( primal, first_ident(&meta_item_vec[0]), span, + &sig, &d_sig, &generics, is_impl, @@ -496,18 +497,62 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty { // Generate `autodiff` intrinsic call // ``` - // std::intrinsics::autodiff(source, diff, (args)) + // std::intrinsics::autodiff(source as fn(..) -> .., diff, (args)) // ``` fn call_autodiff( ecx: &ExtCtxt<'_>, primal: Ident, diff: Ident, span: Span, + p_sig: &FnSig, d_sig: &FnSig, generics: &Generics, is_impl: bool, ) -> rustc_ast::Stmt { let primal_path_expr = gen_turbofish_expr(ecx, primal, generics, span, is_impl); + + let self_ty = || ecx.ty_path(ast::Path::from_ident(Ident::with_dummy_span(kw::SelfUpper))); + let fn_ptr_params: ThinVec = p_sig + .decl + .inputs + .iter() + .map(|param| { + let ty = match ¶m.ty.kind { + TyKind::ImplicitSelf => self_ty(), + TyKind::Ref(lt, mt) if matches!(mt.ty.kind, TyKind::ImplicitSelf) => ecx.ty( + span, + TyKind::Ref(lt.clone(), ast::MutTy { ty: self_ty(), mutbl: mt.mutbl }), + ), + TyKind::Ptr(mt) if matches!(mt.ty.kind, TyKind::ImplicitSelf) => { + ecx.ty(span, TyKind::Ptr(ast::MutTy { ty: self_ty(), mutbl: mt.mutbl })) + } + _ => param.ty.clone(), + }; + ast::Param { + attrs: ast::AttrVec::new(), + ty, + pat: Box::new(ecx.pat_wild(span)), + id: ast::DUMMY_NODE_ID, + span, + is_placeholder: false, + } + }) + .collect(); + let fn_ptr_ty = ecx.ty( + span, + TyKind::FnPtr(Box::new(ast::FnPtrTy { + safety: p_sig.header.safety, + ext: p_sig.header.ext, + generic_params: ThinVec::new(), + decl: Box::new(ast::FnDecl { + inputs: fn_ptr_params, + output: p_sig.decl.output.clone(), + }), + decl_span: span, + })), + ); + let primal_fn_ptr = ecx.expr(span, ast::ExprKind::Cast(primal_path_expr, fn_ptr_ty)); + let diff_path_expr = gen_turbofish_expr(ecx, diff, generics, span, is_impl); let tuple_expr = ecx.expr_tuple( @@ -529,7 +574,7 @@ fn call_autodiff( let call_expr = ecx.expr_call( span, ecx.expr_path(enzyme_path), - vec![primal_path_expr, diff_path_expr, tuple_expr].into(), + vec![primal_fn_ptr, diff_path_expr, tuple_expr].into(), ); ecx.stmt_expr(call_expr) diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index f8f6439a7b0e..1cefdaae5ebd 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -6,7 +6,7 @@ use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods}; use rustc_data_structures::thin_vec::ThinVec; use rustc_hir::attrs::RustcAutodiff; -use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv}; +use rustc_middle::ty::{PseudoCanonicalInput, Ty, TyCtxt, TypingEnv}; use rustc_middle::{bug, ty}; use rustc_target::callconv::PassMode; use tracing::debug; @@ -18,25 +18,23 @@ pub(crate) fn adjust_activity_to_abi<'tcx>( tcx: TyCtxt<'tcx>, - instance: Instance<'tcx>, + fn_ptr_ty: Ty<'tcx>, typing_env: TypingEnv<'tcx>, da: &mut ThinVec, ) { - let fn_ty = instance.ty(tcx, typing_env); - - if !matches!(fn_ty.kind(), ty::FnDef(..)) { - bug!("expected fn def for autodiff, got {:?}", fn_ty); + if !matches!(fn_ptr_ty.kind(), ty::FnPtr(..)) { + bug!("expected fn ptr for autodiff, got {:?}", fn_ptr_ty); } // We don't actually pass the types back into the type system. // All we do is decide how to handle the arguments. - let sig = fn_ty.fn_sig(tcx).skip_binder(); + let fn_sig = fn_ptr_ty.fn_sig(tcx); + let sig = fn_sig.skip_binder(); // FIXME(Sa4dUs): pass proper varargs once we have support for differentiating variadic functions - let Ok(fn_abi) = - tcx.fn_abi_of_instance(typing_env.as_query_input((instance, ty::List::empty()))) + let Ok(fn_abi) = tcx.fn_abi_of_fn_ptr(typing_env.as_query_input((fn_sig, ty::List::empty()))) else { - bug!("failed to get fn_abi of instance with empty varargs"); + bug!("failed to get fn_abi of fn_ptr with empty varargs"); }; let mut new_activities = vec![]; diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 8ceb7ba29737..bd90f596eb3f 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -1322,29 +1322,8 @@ fn codegen_autodiff<'ll, 'tcx>( let ret_ty = sig.output(); let llret_ty = bx.layout_of(ret_ty).llvm_type(bx); - // Get source, diff, and attrs - let (source_id, source_args) = match fn_args.into_type_list(tcx)[0].kind() { - ty::FnDef(def_id, source_params) => (def_id, source_params), - _ => bug!("invalid autodiff intrinsic args"), - }; - - let fn_source = match Instance::try_resolve(tcx, bx.cx.typing_env(), *source_id, source_args) { - Ok(Some(instance)) => instance, - Ok(None) => bug!( - "could not resolve ({:?}, {:?}) to a specific autodiff instance", - source_id, - source_args - ), - Err(_) => { - // An error has already been emitted - return; - } - }; - - let source_symbol = symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE); - let Some(fn_to_diff) = bx.cx.get_function(&source_symbol) else { - bug!("could not find source function") - }; + let source_fn_ptr_ty = fn_args.into_type_list(tcx)[0]; + let fn_to_diff = args[0].immediate(); let (diff_id, diff_args) = match fn_args.into_type_list(tcx)[1].kind() { ty::FnDef(def_id, diff_args) => (def_id, diff_args), @@ -1375,13 +1354,12 @@ fn codegen_autodiff<'ll, 'tcx>( adjust_activity_to_abi( tcx, - fn_source, + source_fn_ptr_ty, TypingEnv::fully_monomorphized(), &mut diff_attrs.input_activity, ); - let fnc_tree = - rustc_middle::ty::fnc_typetrees(tcx, fn_source.ty(tcx, TypingEnv::fully_monomorphized())); + let fnc_tree = rustc_middle::ty::fnc_typetrees(tcx, source_fn_ptr_ty); // Build body generate_enzyme_call( diff --git a/compiler/rustc_mir_transform/src/cross_crate_inline.rs b/compiler/rustc_mir_transform/src/cross_crate_inline.rs index 19ffffdd1eca..53aa5f450dbb 100644 --- a/compiler/rustc_mir_transform/src/cross_crate_inline.rs +++ b/compiler/rustc_mir_transform/src/cross_crate_inline.rs @@ -35,11 +35,6 @@ fn cross_crate_inlinable(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool { return true; } - // FIXME(autodiff): replace this as per discussion in https://github.com/rust-lang/rust/pull/149033#discussion_r2535465880 - if find_attr!(tcx, def_id, RustcAutodiff(..)) { - return true; - } - if find_attr!(tcx, def_id, RustcIntrinsic) { // Intrinsic fallback bodies are always cross-crate inlineable. // To ensure that the MIR inliner doesn't cluelessly try to inline fallback diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs index 3aa55cc8eb9f..3b1b5ba9673a 100644 --- a/compiler/rustc_monomorphize/src/collector.rs +++ b/compiler/rustc_monomorphize/src/collector.rs @@ -205,8 +205,6 @@ //! this is not implemented however: a mono item will be produced //! regardless of whether it is actually needed or not. -mod autodiff; - use std::cell::OnceCell; use std::ops::ControlFlow; @@ -240,7 +238,6 @@ use rustc_span::{DUMMY_SP, Span}; use tracing::{debug, instrument, trace}; -use crate::collector::autodiff::collect_autodiff_fn; use crate::errors::{ self, EncounteredErrorWhileInstantiating, EncounteredErrorWhileInstantiatingGlobalAsm, NoOptimizedMir, RecursionLimit, @@ -990,8 +987,6 @@ fn visit_instance_use<'tcx>( return; } if let Some(intrinsic) = tcx.intrinsic(instance.def_id()) { - collect_autodiff_fn(tcx, instance, intrinsic, output); - if let Some(_requirement) = ValidityRequirement::from_intrinsic(intrinsic.name) { // The intrinsics assert_inhabited, assert_zero_valid, and assert_mem_uninitialized_valid will // be lowered in codegen to nothing or a call to panic_nounwind. So if we encounter any diff --git a/compiler/rustc_monomorphize/src/collector/autodiff.rs b/compiler/rustc_monomorphize/src/collector/autodiff.rs deleted file mode 100644 index 67d4b8c8afff..000000000000 --- a/compiler/rustc_monomorphize/src/collector/autodiff.rs +++ /dev/null @@ -1,50 +0,0 @@ -use rustc_middle::bug; -use rustc_middle::ty::{self, GenericArg, IntrinsicDef, TyCtxt}; - -use crate::collector::{MonoItems, create_fn_mono_item}; - -// Here, we force both primal and diff function to be collected in -// mono so this does not interfere in `autodiff` intrinsics -// codegen process. If they are unused, LLVM will remove them when -// compiling with O3. -// FIXME(autodiff): Remove this whole file, as per discussion in -// https://github.com/rust-lang/rust/pull/149033#discussion_r2535465880 -pub(crate) fn collect_autodiff_fn<'tcx>( - tcx: TyCtxt<'tcx>, - instance: ty::Instance<'tcx>, - intrinsic: IntrinsicDef, - output: &mut MonoItems<'tcx>, -) { - if intrinsic.name != rustc_span::sym::autodiff { - return; - }; - - collect_autodiff_fn_from_arg(instance.args[0], tcx, output); -} - -fn collect_autodiff_fn_from_arg<'tcx>( - arg: GenericArg<'tcx>, - tcx: TyCtxt<'tcx>, - output: &mut MonoItems<'tcx>, -) { - let (instance, span) = match arg.kind() { - ty::GenericArgKind::Type(ty) => match *ty.kind() { - ty::FnDef(def_id, substs) => { - let span = tcx.def_span(def_id); - let instance = ty::Instance::expect_resolve( - tcx, - ty::TypingEnv::non_body_analysis(tcx, def_id), - def_id, - substs, - span, - ); - - (instance, span) - } - _ => bug!("expected autodiff function"), - }, - _ => bug!("expected type when matching autodiff arg"), - }; - - output.push(create_fn_mono_item(tcx, instance, span)); -} diff --git a/tests/pretty/autodiff/autodiff_forward.pp b/tests/pretty/autodiff/autodiff_forward.pp index c64da3f60884..746754637f5c 100644 --- a/tests/pretty/autodiff/autodiff_forward.pp +++ b/tests/pretty/autodiff/autodiff_forward.pp @@ -35,7 +35,8 @@ pub fn f1(x: &[f64], y: f64) -> f64 { } #[rustc_autodiff(Forward, 1, Dual, Const, Dual)] pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) { - ::core::intrinsics::autodiff(f1::<>, df1::<>, (x, bx_0, y)) + ::core::intrinsics::autodiff(f1::<> as fn(_: &[f64], _: f64) -> f64, + df1::<>, (x, bx_0, y)) } #[rustc_autodiff] pub fn f2(x: &[f64], y: f64) -> f64 { @@ -43,7 +44,8 @@ pub fn f2(x: &[f64], y: f64) -> f64 { } #[rustc_autodiff(Forward, 1, Dual, Const, Const)] pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64 { - ::core::intrinsics::autodiff(f2::<>, df2::<>, (x, bx_0, y)) + ::core::intrinsics::autodiff(f2::<> as fn(_: &[f64], _: f64) -> f64, + df2::<>, (x, bx_0, y)) } #[rustc_autodiff] pub fn f3(x: &[f64], y: f64) -> f64 { @@ -51,27 +53,33 @@ pub fn f3(x: &[f64], y: f64) -> f64 { } #[rustc_autodiff(Forward, 1, Dual, Const, Const)] pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64 { - ::core::intrinsics::autodiff(f3::<>, df3::<>, (x, bx_0, y)) + ::core::intrinsics::autodiff(f3::<> as fn(_: &[f64], _: f64) -> f64, + df3::<>, (x, bx_0, y)) } #[rustc_autodiff] pub fn f4() {} #[rustc_autodiff(Forward, 1, None)] -pub fn df4() -> () { ::core::intrinsics::autodiff(f4::<>, df4::<>, ()) } +pub fn df4() -> () { + ::core::intrinsics::autodiff(f4::<> as fn(), df4::<>, ()) +} #[rustc_autodiff] pub fn f5(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Const, Dual, Const)] pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64 { - ::core::intrinsics::autodiff(f5::<>, df5_y::<>, (x, y, by_0)) + ::core::intrinsics::autodiff(f5::<> as fn(_: &[f64], _: f64) -> f64, + df5_y::<>, (x, y, by_0)) } #[rustc_autodiff(Forward, 1, Dual, Const, Const)] pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64 { - ::core::intrinsics::autodiff(f5::<>, df5_x::<>, (x, bx_0, y)) + ::core::intrinsics::autodiff(f5::<> as fn(_: &[f64], _: f64) -> f64, + df5_x::<>, (x, bx_0, y)) } #[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)] pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 { - ::core::intrinsics::autodiff(f5::<>, df5_rev::<>, (x, dx_0, y, dret)) + ::core::intrinsics::autodiff(f5::<> as fn(_: &[f64], _: f64) -> f64, + df5_rev::<>, (x, dx_0, y, dret)) } struct DoesNotImplDefault; #[rustc_autodiff] @@ -80,13 +88,14 @@ pub fn f6() -> DoesNotImplDefault { } #[rustc_autodiff(Forward, 1, Const)] pub fn df6() -> DoesNotImplDefault { - ::core::intrinsics::autodiff(f6::<>, df6::<>, ()) + ::core::intrinsics::autodiff(f6::<> as fn() -> DoesNotImplDefault, + df6::<>, ()) } #[rustc_autodiff] pub fn f7(x: f32) -> () {} #[rustc_autodiff(Forward, 1, Const, None)] pub fn df7(x: f32) -> () { - ::core::intrinsics::autodiff(f7::<>, df7::<>, (x,)) + ::core::intrinsics::autodiff(f7::<> as fn(_: f32) -> (), df7::<>, (x,)) } #[no_mangle] #[rustc_autodiff] @@ -94,29 +103,32 @@ fn f8(x: &f32) -> f32 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 4, Dual, Dual)] fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32) -> [f32; 5usize] { - ::core::intrinsics::autodiff(f8::<>, f8_3::<>, + ::core::intrinsics::autodiff(f8::<> as fn(_: &f32) -> f32, f8_3::<>, (x, bx_0, bx_1, bx_2, bx_3)) } #[rustc_autodiff(Forward, 4, Dual, DualOnly)] fn f8_2(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32) -> [f32; 4usize] { - ::core::intrinsics::autodiff(f8::<>, f8_2::<>, + ::core::intrinsics::autodiff(f8::<> as fn(_: &f32) -> f32, f8_2::<>, (x, bx_0, bx_1, bx_2, bx_3)) } #[rustc_autodiff(Forward, 1, Dual, DualOnly)] fn f8_1(x: &f32, bx_0: &f32) -> f32 { - ::core::intrinsics::autodiff(f8::<>, f8_1::<>, (x, bx_0)) + ::core::intrinsics::autodiff(f8::<> as fn(_: &f32) -> f32, f8_1::<>, + (x, bx_0)) } pub fn f9() { #[rustc_autodiff] fn inner(x: f32) -> f32 { x * x } #[rustc_autodiff(Forward, 1, Dual, Dual)] fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) { - ::core::intrinsics::autodiff(inner::<>, d_inner_2::<>, (x, bx_0)) + ::core::intrinsics::autodiff(inner::<> as fn(_: f32) -> f32, + d_inner_2::<>, (x, bx_0)) } #[rustc_autodiff(Forward, 1, Dual, DualOnly)] fn d_inner_1(x: f32, bx_0: f32) -> f32 { - ::core::intrinsics::autodiff(inner::<>, d_inner_1::<>, (x, bx_0)) + ::core::intrinsics::autodiff(inner::<> as fn(_: f32) -> f32, + d_inner_1::<>, (x, bx_0)) } } #[rustc_autodiff] @@ -124,6 +136,7 @@ pub fn f10 + Copy>(x: &T) -> T { *x * *x } #[rustc_autodiff(Reverse, 1, Duplicated, Active)] pub fn d_square + Copy>(x: &T, dx_0: &mut T, dret: T) -> T { - ::core::intrinsics::autodiff(f10::, d_square::, (x, dx_0, dret)) + ::core::intrinsics::autodiff(f10:: as fn(_: &T) -> T, d_square::, + (x, dx_0, dret)) } fn main() {} diff --git a/tests/pretty/autodiff/autodiff_reverse.pp b/tests/pretty/autodiff/autodiff_reverse.pp index 61ab121b31bc..e2088e0ac13d 100644 --- a/tests/pretty/autodiff/autodiff_reverse.pp +++ b/tests/pretty/autodiff/autodiff_reverse.pp @@ -28,32 +28,37 @@ pub fn f1(x: &[f64], y: f64) -> f64 { } #[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)] pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 { - ::core::intrinsics::autodiff(f1::<>, df1::<>, (x, dx_0, y, dret)) + ::core::intrinsics::autodiff(f1::<> as fn(_: &[f64], _: f64) -> f64, + df1::<>, (x, dx_0, y, dret)) } #[rustc_autodiff] pub fn f2() {} #[rustc_autodiff(Reverse, 1, None)] -pub fn df2() { ::core::intrinsics::autodiff(f2::<>, df2::<>, ()) } +pub fn df2() { ::core::intrinsics::autodiff(f2::<> as fn(), df2::<>, ()) } #[rustc_autodiff] pub fn f3(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)] pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 { - ::core::intrinsics::autodiff(f3::<>, df3::<>, (x, dx_0, y, dret)) + ::core::intrinsics::autodiff(f3::<> as fn(_: &[f64], _: f64) -> f64, + df3::<>, (x, dx_0, y, dret)) } enum Foo { Reverse, } use Foo::Reverse; #[rustc_autodiff] pub fn f4(x: f32) { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, Const, None)] -pub fn df4(x: f32) { ::core::intrinsics::autodiff(f4::<>, df4::<>, (x,)) } +pub fn df4(x: f32) { + ::core::intrinsics::autodiff(f4::<> as fn(_: f32), df4::<>, (x,)) +} #[rustc_autodiff] pub fn f5(x: *const f32, y: &f32) { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, DuplicatedOnly, Duplicated, None)] pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32) { - ::core::intrinsics::autodiff(f5::<>, df5::<>, (x, dx_0, y, dy_0)) + ::core::intrinsics::autodiff(f5::<> as fn(_: *const f32, _: &f32), + df5::<>, (x, dx_0, y, dy_0)) } fn main() {} diff --git a/tests/pretty/autodiff/inherent_impl.pp b/tests/pretty/autodiff/inherent_impl.pp index 1c83c66c8edf..d3a5a71b8bcb 100644 --- a/tests/pretty/autodiff/inherent_impl.pp +++ b/tests/pretty/autodiff/inherent_impl.pp @@ -30,7 +30,7 @@ impl MyTrait for Foo { } #[rustc_autodiff(Reverse, 1, Const, Active, Active)] fn df(&self, x: f64, dret: f64) -> (f64, f64) { - ::core::intrinsics::autodiff(Self::f::<>, Self::df::<>, - (self, x, dret)) + ::core::intrinsics::autodiff(Self::f::<> as + fn(_: &Self, _: f64) -> f64, Self::df::<>, (self, x, dret)) } }