mirror of
https://github.com/rust-lang/rust.git
synced 2026-04-27 18:57:42 +03:00
Auto merge of #153379 - TKanX:refactor/149164-simplify-autodiff-rlib, r=ZuseZ4
refactor(autodiff): Simplify Autodiff Handling of `rlib` Dependencies ### Summary: Resolves the two FIXMEs left in rust-lang/rust#149033, per @bjorn3 guidance in [the discussion](https://github.com/rust-lang/rust/pull/149033#discussion_r2535465880). Closes rust-lang/rust#149164 r? @ZuseZ4 cc @bjorn3
This commit is contained in:
@@ -20,7 +20,7 @@ mod llvm_enzyme {
|
||||
};
|
||||
use rustc_expand::base::{Annotatable, ExtCtxt};
|
||||
use rustc_hir::attrs::RustcAutodiff;
|
||||
use rustc_span::{Ident, Span, Symbol, sym};
|
||||
use rustc_span::{Ident, Span, Symbol, kw, sym};
|
||||
use thin_vec::{ThinVec, thin_vec};
|
||||
use tracing::{debug, trace};
|
||||
|
||||
@@ -197,7 +197,7 @@ pub(crate) fn expand_reverse(
|
||||
/// }
|
||||
/// #[rustc_autodiff(Reverse, Duplicated, Active)]
|
||||
/// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 {
|
||||
/// std::intrinsics::autodiff(sin::<>, cos_box::<>, (x, dx, dret))
|
||||
/// std::intrinsics::autodiff(sin::<> as fn(..) -> .., cos_box::<>, (x, dx, dret))
|
||||
/// }
|
||||
/// ```
|
||||
/// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
|
||||
@@ -326,6 +326,7 @@ pub(crate) fn expand_with_mode(
|
||||
primal,
|
||||
first_ident(&meta_item_vec[0]),
|
||||
span,
|
||||
&sig,
|
||||
&d_sig,
|
||||
&generics,
|
||||
is_impl,
|
||||
@@ -496,18 +497,62 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
|
||||
|
||||
// Generate `autodiff` intrinsic call
|
||||
// ```
|
||||
// std::intrinsics::autodiff(source, diff, (args))
|
||||
// std::intrinsics::autodiff(source as fn(..) -> .., diff, (args))
|
||||
// ```
|
||||
fn call_autodiff(
|
||||
ecx: &ExtCtxt<'_>,
|
||||
primal: Ident,
|
||||
diff: Ident,
|
||||
span: Span,
|
||||
p_sig: &FnSig,
|
||||
d_sig: &FnSig,
|
||||
generics: &Generics,
|
||||
is_impl: bool,
|
||||
) -> rustc_ast::Stmt {
|
||||
let primal_path_expr = gen_turbofish_expr(ecx, primal, generics, span, is_impl);
|
||||
|
||||
let self_ty = || ecx.ty_path(ast::Path::from_ident(Ident::with_dummy_span(kw::SelfUpper)));
|
||||
let fn_ptr_params: ThinVec<ast::Param> = p_sig
|
||||
.decl
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|param| {
|
||||
let ty = match ¶m.ty.kind {
|
||||
TyKind::ImplicitSelf => self_ty(),
|
||||
TyKind::Ref(lt, mt) if matches!(mt.ty.kind, TyKind::ImplicitSelf) => ecx.ty(
|
||||
span,
|
||||
TyKind::Ref(lt.clone(), ast::MutTy { ty: self_ty(), mutbl: mt.mutbl }),
|
||||
),
|
||||
TyKind::Ptr(mt) if matches!(mt.ty.kind, TyKind::ImplicitSelf) => {
|
||||
ecx.ty(span, TyKind::Ptr(ast::MutTy { ty: self_ty(), mutbl: mt.mutbl }))
|
||||
}
|
||||
_ => param.ty.clone(),
|
||||
};
|
||||
ast::Param {
|
||||
attrs: ast::AttrVec::new(),
|
||||
ty,
|
||||
pat: Box::new(ecx.pat_wild(span)),
|
||||
id: ast::DUMMY_NODE_ID,
|
||||
span,
|
||||
is_placeholder: false,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let fn_ptr_ty = ecx.ty(
|
||||
span,
|
||||
TyKind::FnPtr(Box::new(ast::FnPtrTy {
|
||||
safety: p_sig.header.safety,
|
||||
ext: p_sig.header.ext,
|
||||
generic_params: ThinVec::new(),
|
||||
decl: Box::new(ast::FnDecl {
|
||||
inputs: fn_ptr_params,
|
||||
output: p_sig.decl.output.clone(),
|
||||
}),
|
||||
decl_span: span,
|
||||
})),
|
||||
);
|
||||
let primal_fn_ptr = ecx.expr(span, ast::ExprKind::Cast(primal_path_expr, fn_ptr_ty));
|
||||
|
||||
let diff_path_expr = gen_turbofish_expr(ecx, diff, generics, span, is_impl);
|
||||
|
||||
let tuple_expr = ecx.expr_tuple(
|
||||
@@ -529,7 +574,7 @@ fn call_autodiff(
|
||||
let call_expr = ecx.expr_call(
|
||||
span,
|
||||
ecx.expr_path(enzyme_path),
|
||||
vec![primal_path_expr, diff_path_expr, tuple_expr].into(),
|
||||
vec![primal_fn_ptr, diff_path_expr, tuple_expr].into(),
|
||||
);
|
||||
|
||||
ecx.stmt_expr(call_expr)
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
|
||||
use rustc_data_structures::thin_vec::ThinVec;
|
||||
use rustc_hir::attrs::RustcAutodiff;
|
||||
use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv};
|
||||
use rustc_middle::ty::{PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
|
||||
use rustc_middle::{bug, ty};
|
||||
use rustc_target::callconv::PassMode;
|
||||
use tracing::debug;
|
||||
@@ -18,25 +18,23 @@
|
||||
|
||||
pub(crate) fn adjust_activity_to_abi<'tcx>(
|
||||
tcx: TyCtxt<'tcx>,
|
||||
instance: Instance<'tcx>,
|
||||
fn_ptr_ty: Ty<'tcx>,
|
||||
typing_env: TypingEnv<'tcx>,
|
||||
da: &mut ThinVec<DiffActivity>,
|
||||
) {
|
||||
let fn_ty = instance.ty(tcx, typing_env);
|
||||
|
||||
if !matches!(fn_ty.kind(), ty::FnDef(..)) {
|
||||
bug!("expected fn def for autodiff, got {:?}", fn_ty);
|
||||
if !matches!(fn_ptr_ty.kind(), ty::FnPtr(..)) {
|
||||
bug!("expected fn ptr for autodiff, got {:?}", fn_ptr_ty);
|
||||
}
|
||||
|
||||
// We don't actually pass the types back into the type system.
|
||||
// All we do is decide how to handle the arguments.
|
||||
let sig = fn_ty.fn_sig(tcx).skip_binder();
|
||||
let fn_sig = fn_ptr_ty.fn_sig(tcx);
|
||||
let sig = fn_sig.skip_binder();
|
||||
|
||||
// FIXME(Sa4dUs): pass proper varargs once we have support for differentiating variadic functions
|
||||
let Ok(fn_abi) =
|
||||
tcx.fn_abi_of_instance(typing_env.as_query_input((instance, ty::List::empty())))
|
||||
let Ok(fn_abi) = tcx.fn_abi_of_fn_ptr(typing_env.as_query_input((fn_sig, ty::List::empty())))
|
||||
else {
|
||||
bug!("failed to get fn_abi of instance with empty varargs");
|
||||
bug!("failed to get fn_abi of fn_ptr with empty varargs");
|
||||
};
|
||||
|
||||
let mut new_activities = vec![];
|
||||
|
||||
@@ -1322,29 +1322,8 @@ fn codegen_autodiff<'ll, 'tcx>(
|
||||
let ret_ty = sig.output();
|
||||
let llret_ty = bx.layout_of(ret_ty).llvm_type(bx);
|
||||
|
||||
// Get source, diff, and attrs
|
||||
let (source_id, source_args) = match fn_args.into_type_list(tcx)[0].kind() {
|
||||
ty::FnDef(def_id, source_params) => (def_id, source_params),
|
||||
_ => bug!("invalid autodiff intrinsic args"),
|
||||
};
|
||||
|
||||
let fn_source = match Instance::try_resolve(tcx, bx.cx.typing_env(), *source_id, source_args) {
|
||||
Ok(Some(instance)) => instance,
|
||||
Ok(None) => bug!(
|
||||
"could not resolve ({:?}, {:?}) to a specific autodiff instance",
|
||||
source_id,
|
||||
source_args
|
||||
),
|
||||
Err(_) => {
|
||||
// An error has already been emitted
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let source_symbol = symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE);
|
||||
let Some(fn_to_diff) = bx.cx.get_function(&source_symbol) else {
|
||||
bug!("could not find source function")
|
||||
};
|
||||
let source_fn_ptr_ty = fn_args.into_type_list(tcx)[0];
|
||||
let fn_to_diff = args[0].immediate();
|
||||
|
||||
let (diff_id, diff_args) = match fn_args.into_type_list(tcx)[1].kind() {
|
||||
ty::FnDef(def_id, diff_args) => (def_id, diff_args),
|
||||
@@ -1375,13 +1354,12 @@ fn codegen_autodiff<'ll, 'tcx>(
|
||||
|
||||
adjust_activity_to_abi(
|
||||
tcx,
|
||||
fn_source,
|
||||
source_fn_ptr_ty,
|
||||
TypingEnv::fully_monomorphized(),
|
||||
&mut diff_attrs.input_activity,
|
||||
);
|
||||
|
||||
let fnc_tree =
|
||||
rustc_middle::ty::fnc_typetrees(tcx, fn_source.ty(tcx, TypingEnv::fully_monomorphized()));
|
||||
let fnc_tree = rustc_middle::ty::fnc_typetrees(tcx, source_fn_ptr_ty);
|
||||
|
||||
// Build body
|
||||
generate_enzyme_call(
|
||||
|
||||
@@ -35,11 +35,6 @@ fn cross_crate_inlinable(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool {
|
||||
return true;
|
||||
}
|
||||
|
||||
// FIXME(autodiff): replace this as per discussion in https://github.com/rust-lang/rust/pull/149033#discussion_r2535465880
|
||||
if find_attr!(tcx, def_id, RustcAutodiff(..)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if find_attr!(tcx, def_id, RustcIntrinsic) {
|
||||
// Intrinsic fallback bodies are always cross-crate inlineable.
|
||||
// To ensure that the MIR inliner doesn't cluelessly try to inline fallback
|
||||
|
||||
@@ -205,8 +205,6 @@
|
||||
//! this is not implemented however: a mono item will be produced
|
||||
//! regardless of whether it is actually needed or not.
|
||||
|
||||
mod autodiff;
|
||||
|
||||
use std::cell::OnceCell;
|
||||
use std::ops::ControlFlow;
|
||||
|
||||
@@ -240,7 +238,6 @@
|
||||
use rustc_span::{DUMMY_SP, Span};
|
||||
use tracing::{debug, instrument, trace};
|
||||
|
||||
use crate::collector::autodiff::collect_autodiff_fn;
|
||||
use crate::errors::{
|
||||
self, EncounteredErrorWhileInstantiating, EncounteredErrorWhileInstantiatingGlobalAsm,
|
||||
NoOptimizedMir, RecursionLimit,
|
||||
@@ -990,8 +987,6 @@ fn visit_instance_use<'tcx>(
|
||||
return;
|
||||
}
|
||||
if let Some(intrinsic) = tcx.intrinsic(instance.def_id()) {
|
||||
collect_autodiff_fn(tcx, instance, intrinsic, output);
|
||||
|
||||
if let Some(_requirement) = ValidityRequirement::from_intrinsic(intrinsic.name) {
|
||||
// The intrinsics assert_inhabited, assert_zero_valid, and assert_mem_uninitialized_valid will
|
||||
// be lowered in codegen to nothing or a call to panic_nounwind. So if we encounter any
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
use rustc_middle::bug;
|
||||
use rustc_middle::ty::{self, GenericArg, IntrinsicDef, TyCtxt};
|
||||
|
||||
use crate::collector::{MonoItems, create_fn_mono_item};
|
||||
|
||||
// Here, we force both primal and diff function to be collected in
|
||||
// mono so this does not interfere in `autodiff` intrinsics
|
||||
// codegen process. If they are unused, LLVM will remove them when
|
||||
// compiling with O3.
|
||||
// FIXME(autodiff): Remove this whole file, as per discussion in
|
||||
// https://github.com/rust-lang/rust/pull/149033#discussion_r2535465880
|
||||
pub(crate) fn collect_autodiff_fn<'tcx>(
|
||||
tcx: TyCtxt<'tcx>,
|
||||
instance: ty::Instance<'tcx>,
|
||||
intrinsic: IntrinsicDef,
|
||||
output: &mut MonoItems<'tcx>,
|
||||
) {
|
||||
if intrinsic.name != rustc_span::sym::autodiff {
|
||||
return;
|
||||
};
|
||||
|
||||
collect_autodiff_fn_from_arg(instance.args[0], tcx, output);
|
||||
}
|
||||
|
||||
fn collect_autodiff_fn_from_arg<'tcx>(
|
||||
arg: GenericArg<'tcx>,
|
||||
tcx: TyCtxt<'tcx>,
|
||||
output: &mut MonoItems<'tcx>,
|
||||
) {
|
||||
let (instance, span) = match arg.kind() {
|
||||
ty::GenericArgKind::Type(ty) => match *ty.kind() {
|
||||
ty::FnDef(def_id, substs) => {
|
||||
let span = tcx.def_span(def_id);
|
||||
let instance = ty::Instance::expect_resolve(
|
||||
tcx,
|
||||
ty::TypingEnv::non_body_analysis(tcx, def_id),
|
||||
def_id,
|
||||
substs,
|
||||
span,
|
||||
);
|
||||
|
||||
(instance, span)
|
||||
}
|
||||
_ => bug!("expected autodiff function"),
|
||||
},
|
||||
_ => bug!("expected type when matching autodiff arg"),
|
||||
};
|
||||
|
||||
output.push(create_fn_mono_item(tcx, instance, span));
|
||||
}
|
||||
@@ -35,7 +35,8 @@ pub fn f1(x: &[f64], y: f64) -> f64 {
|
||||
}
|
||||
#[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
|
||||
pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) {
|
||||
::core::intrinsics::autodiff(f1::<>, df1::<>, (x, bx_0, y))
|
||||
::core::intrinsics::autodiff(f1::<> as fn(_: &[f64], _: f64) -> f64,
|
||||
df1::<>, (x, bx_0, y))
|
||||
}
|
||||
#[rustc_autodiff]
|
||||
pub fn f2(x: &[f64], y: f64) -> f64 {
|
||||
@@ -43,7 +44,8 @@ pub fn f2(x: &[f64], y: f64) -> f64 {
|
||||
}
|
||||
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
|
||||
pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
|
||||
::core::intrinsics::autodiff(f2::<>, df2::<>, (x, bx_0, y))
|
||||
::core::intrinsics::autodiff(f2::<> as fn(_: &[f64], _: f64) -> f64,
|
||||
df2::<>, (x, bx_0, y))
|
||||
}
|
||||
#[rustc_autodiff]
|
||||
pub fn f3(x: &[f64], y: f64) -> f64 {
|
||||
@@ -51,27 +53,33 @@ pub fn f3(x: &[f64], y: f64) -> f64 {
|
||||
}
|
||||
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
|
||||
pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
|
||||
::core::intrinsics::autodiff(f3::<>, df3::<>, (x, bx_0, y))
|
||||
::core::intrinsics::autodiff(f3::<> as fn(_: &[f64], _: f64) -> f64,
|
||||
df3::<>, (x, bx_0, y))
|
||||
}
|
||||
#[rustc_autodiff]
|
||||
pub fn f4() {}
|
||||
#[rustc_autodiff(Forward, 1, None)]
|
||||
pub fn df4() -> () { ::core::intrinsics::autodiff(f4::<>, df4::<>, ()) }
|
||||
pub fn df4() -> () {
|
||||
::core::intrinsics::autodiff(f4::<> as fn(), df4::<>, ())
|
||||
}
|
||||
#[rustc_autodiff]
|
||||
pub fn f5(x: &[f64], y: f64) -> f64 {
|
||||
::core::panicking::panic("not implemented")
|
||||
}
|
||||
#[rustc_autodiff(Forward, 1, Const, Dual, Const)]
|
||||
pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64 {
|
||||
::core::intrinsics::autodiff(f5::<>, df5_y::<>, (x, y, by_0))
|
||||
::core::intrinsics::autodiff(f5::<> as fn(_: &[f64], _: f64) -> f64,
|
||||
df5_y::<>, (x, y, by_0))
|
||||
}
|
||||
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
|
||||
pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
|
||||
::core::intrinsics::autodiff(f5::<>, df5_x::<>, (x, bx_0, y))
|
||||
::core::intrinsics::autodiff(f5::<> as fn(_: &[f64], _: f64) -> f64,
|
||||
df5_x::<>, (x, bx_0, y))
|
||||
}
|
||||
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
|
||||
pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
|
||||
::core::intrinsics::autodiff(f5::<>, df5_rev::<>, (x, dx_0, y, dret))
|
||||
::core::intrinsics::autodiff(f5::<> as fn(_: &[f64], _: f64) -> f64,
|
||||
df5_rev::<>, (x, dx_0, y, dret))
|
||||
}
|
||||
struct DoesNotImplDefault;
|
||||
#[rustc_autodiff]
|
||||
@@ -80,13 +88,14 @@ pub fn f6() -> DoesNotImplDefault {
|
||||
}
|
||||
#[rustc_autodiff(Forward, 1, Const)]
|
||||
pub fn df6() -> DoesNotImplDefault {
|
||||
::core::intrinsics::autodiff(f6::<>, df6::<>, ())
|
||||
::core::intrinsics::autodiff(f6::<> as fn() -> DoesNotImplDefault,
|
||||
df6::<>, ())
|
||||
}
|
||||
#[rustc_autodiff]
|
||||
pub fn f7(x: f32) -> () {}
|
||||
#[rustc_autodiff(Forward, 1, Const, None)]
|
||||
pub fn df7(x: f32) -> () {
|
||||
::core::intrinsics::autodiff(f7::<>, df7::<>, (x,))
|
||||
::core::intrinsics::autodiff(f7::<> as fn(_: f32) -> (), df7::<>, (x,))
|
||||
}
|
||||
#[no_mangle]
|
||||
#[rustc_autodiff]
|
||||
@@ -94,29 +103,32 @@ fn f8(x: &f32) -> f32 { ::core::panicking::panic("not implemented") }
|
||||
#[rustc_autodiff(Forward, 4, Dual, Dual)]
|
||||
fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
|
||||
-> [f32; 5usize] {
|
||||
::core::intrinsics::autodiff(f8::<>, f8_3::<>,
|
||||
::core::intrinsics::autodiff(f8::<> as fn(_: &f32) -> f32, f8_3::<>,
|
||||
(x, bx_0, bx_1, bx_2, bx_3))
|
||||
}
|
||||
#[rustc_autodiff(Forward, 4, Dual, DualOnly)]
|
||||
fn f8_2(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
|
||||
-> [f32; 4usize] {
|
||||
::core::intrinsics::autodiff(f8::<>, f8_2::<>,
|
||||
::core::intrinsics::autodiff(f8::<> as fn(_: &f32) -> f32, f8_2::<>,
|
||||
(x, bx_0, bx_1, bx_2, bx_3))
|
||||
}
|
||||
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
|
||||
fn f8_1(x: &f32, bx_0: &f32) -> f32 {
|
||||
::core::intrinsics::autodiff(f8::<>, f8_1::<>, (x, bx_0))
|
||||
::core::intrinsics::autodiff(f8::<> as fn(_: &f32) -> f32, f8_1::<>,
|
||||
(x, bx_0))
|
||||
}
|
||||
pub fn f9() {
|
||||
#[rustc_autodiff]
|
||||
fn inner(x: f32) -> f32 { x * x }
|
||||
#[rustc_autodiff(Forward, 1, Dual, Dual)]
|
||||
fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) {
|
||||
::core::intrinsics::autodiff(inner::<>, d_inner_2::<>, (x, bx_0))
|
||||
::core::intrinsics::autodiff(inner::<> as fn(_: f32) -> f32,
|
||||
d_inner_2::<>, (x, bx_0))
|
||||
}
|
||||
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
|
||||
fn d_inner_1(x: f32, bx_0: f32) -> f32 {
|
||||
::core::intrinsics::autodiff(inner::<>, d_inner_1::<>, (x, bx_0))
|
||||
::core::intrinsics::autodiff(inner::<> as fn(_: f32) -> f32,
|
||||
d_inner_1::<>, (x, bx_0))
|
||||
}
|
||||
}
|
||||
#[rustc_autodiff]
|
||||
@@ -124,6 +136,7 @@ pub fn f10<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T { *x * *x }
|
||||
#[rustc_autodiff(Reverse, 1, Duplicated, Active)]
|
||||
pub fn d_square<T: std::ops::Mul<Output = T> +
|
||||
Copy>(x: &T, dx_0: &mut T, dret: T) -> T {
|
||||
::core::intrinsics::autodiff(f10::<T>, d_square::<T>, (x, dx_0, dret))
|
||||
::core::intrinsics::autodiff(f10::<T> as fn(_: &T) -> T, d_square::<T>,
|
||||
(x, dx_0, dret))
|
||||
}
|
||||
fn main() {}
|
||||
|
||||
@@ -28,32 +28,37 @@ pub fn f1(x: &[f64], y: f64) -> f64 {
|
||||
}
|
||||
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
|
||||
pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
|
||||
::core::intrinsics::autodiff(f1::<>, df1::<>, (x, dx_0, y, dret))
|
||||
::core::intrinsics::autodiff(f1::<> as fn(_: &[f64], _: f64) -> f64,
|
||||
df1::<>, (x, dx_0, y, dret))
|
||||
}
|
||||
#[rustc_autodiff]
|
||||
pub fn f2() {}
|
||||
#[rustc_autodiff(Reverse, 1, None)]
|
||||
pub fn df2() { ::core::intrinsics::autodiff(f2::<>, df2::<>, ()) }
|
||||
pub fn df2() { ::core::intrinsics::autodiff(f2::<> as fn(), df2::<>, ()) }
|
||||
#[rustc_autodiff]
|
||||
pub fn f3(x: &[f64], y: f64) -> f64 {
|
||||
::core::panicking::panic("not implemented")
|
||||
}
|
||||
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
|
||||
pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
|
||||
::core::intrinsics::autodiff(f3::<>, df3::<>, (x, dx_0, y, dret))
|
||||
::core::intrinsics::autodiff(f3::<> as fn(_: &[f64], _: f64) -> f64,
|
||||
df3::<>, (x, dx_0, y, dret))
|
||||
}
|
||||
enum Foo { Reverse, }
|
||||
use Foo::Reverse;
|
||||
#[rustc_autodiff]
|
||||
pub fn f4(x: f32) { ::core::panicking::panic("not implemented") }
|
||||
#[rustc_autodiff(Reverse, 1, Const, None)]
|
||||
pub fn df4(x: f32) { ::core::intrinsics::autodiff(f4::<>, df4::<>, (x,)) }
|
||||
pub fn df4(x: f32) {
|
||||
::core::intrinsics::autodiff(f4::<> as fn(_: f32), df4::<>, (x,))
|
||||
}
|
||||
#[rustc_autodiff]
|
||||
pub fn f5(x: *const f32, y: &f32) {
|
||||
::core::panicking::panic("not implemented")
|
||||
}
|
||||
#[rustc_autodiff(Reverse, 1, DuplicatedOnly, Duplicated, None)]
|
||||
pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32) {
|
||||
::core::intrinsics::autodiff(f5::<>, df5::<>, (x, dx_0, y, dy_0))
|
||||
::core::intrinsics::autodiff(f5::<> as fn(_: *const f32, _: &f32),
|
||||
df5::<>, (x, dx_0, y, dy_0))
|
||||
}
|
||||
fn main() {}
|
||||
|
||||
@@ -30,7 +30,7 @@ impl MyTrait for Foo {
|
||||
}
|
||||
#[rustc_autodiff(Reverse, 1, Const, Active, Active)]
|
||||
fn df(&self, x: f64, dret: f64) -> (f64, f64) {
|
||||
::core::intrinsics::autodiff(Self::f::<>, Self::df::<>,
|
||||
(self, x, dret))
|
||||
::core::intrinsics::autodiff(Self::f::<> as
|
||||
fn(_: &Self, _: f64) -> f64, Self::df::<>, (self, x, dret))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user