Allow applying autodiff macros to trait functions.

It will use enzyme to generate a default derivative implementation,
which can be overwritten by the user.
This commit is contained in:
Manuel Drehwald
2026-03-23 11:27:06 +01:00
parent a63150b9cb
commit 26c9f7255a
5 changed files with 94 additions and 16 deletions
@@ -24,6 +24,7 @@ impl<S: Stage> SingleAttributeParser<S> for RustcAutodiffParser {
Allow(Target::Fn),
Allow(Target::Method(MethodKind::Inherent)),
Allow(Target::Method(MethodKind::Trait { body: true })),
Allow(Target::Method(MethodKind::Trait { body: false })),
Allow(Target::Method(MethodKind::TraitImpl)),
]);
const TEMPLATE: AttributeTemplate = template!(
+16 -14
View File
@@ -224,16 +224,18 @@ pub(crate) fn expand_with_mode(
}
_ => None,
},
Annotatable::AssocItem(assoc_item, Impl { of_trait: _ }) => match &assoc_item.kind {
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => Some((
assoc_item.vis.clone(),
sig.clone(),
ident.clone(),
generics.clone(),
true,
)),
_ => None,
},
Annotatable::AssocItem(assoc_item, _ctxt @ (Impl { of_trait: _ } | Trait)) => {
match &assoc_item.kind {
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => Some((
assoc_item.vis.clone(),
sig.clone(),
ident.clone(),
generics.clone(),
true,
)),
_ => None,
}
}
_ => None,
}) else {
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
@@ -393,14 +395,14 @@ fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
}
Annotatable::Item(iitem.clone())
}
Annotatable::AssocItem(ref mut assoc_item, i @ Impl { .. }) => {
Annotatable::AssocItem(ref mut assoc_item, ctxt @ (Impl { .. } | Trait)) => {
if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
assoc_item.attrs.push(attr);
}
if assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
has_inline_never = true;
}
Annotatable::AssocItem(assoc_item.clone(), i)
Annotatable::AssocItem(assoc_item.clone(), ctxt)
}
Annotatable::Stmt(ref mut stmt) => {
match stmt.kind {
@@ -441,7 +443,7 @@ fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
}
let d_annotatable = match &item {
Annotatable::AssocItem(_, _) => {
Annotatable::AssocItem(_, ctxt) => {
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(d_fn);
let d_fn = Box::new(ast::AssocItem {
attrs: d_attrs,
@@ -451,7 +453,7 @@ fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
kind: assoc_item,
tokens: None,
});
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
Annotatable::AssocItem(d_fn, *ctxt)
}
Annotatable::Item(_) => {
let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn));
+2 -2
View File
@@ -149,14 +149,14 @@ pub fn expect_item(self) -> Box<ast::Item> {
pub fn expect_trait_item(self) -> Box<ast::AssocItem> {
match self {
Annotatable::AssocItem(i, AssocCtxt::Trait) => i,
_ => panic!("expected Item"),
_ => panic!("expected trait item"),
}
}
pub fn expect_impl_item(self) -> Box<ast::AssocItem> {
match self {
Annotatable::AssocItem(i, AssocCtxt::Impl { .. }) => i,
_ => panic!("expected Item"),
_ => panic!("expected impl item"),
}
}
+43
View File
@@ -0,0 +1,43 @@
//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat
//@ no-prefer-dynamic
//@ needs-enzyme
// Just check it does not crash for now
// CHECK: ;
#![feature(autodiff)]
#![feature(core_intrinsics)]
#![feature(rustc_attrs)]
use std::autodiff::autodiff_reverse;
struct Foo {
a: f64,
}
trait MyTrait {
#[rustc_autodiff]
fn f(&self, x: f64) -> f64;
#[rustc_autodiff(Reverse, 1, Const, Active, Active)]
fn df(&self, x: f64, seed: f64) -> (f64, f64) {
std::hint::black_box(seed);
std::hint::black_box(x);
::std::intrinsics::autodiff(
Self::f as for<'a> fn(&'a Self, _: f64) -> f64,
Self::df,
(self, x, seed),
)
}
}
impl MyTrait for Foo {
fn f(&self, x: f64) -> f64 {
x.sin()
}
}
fn main() {
let foo = Foo { a: 3.0f64 };
dbg!(foo.df(2.0, 1.0));
dbg!(2.0_f64.cos());
}
+32
View File
@@ -0,0 +1,32 @@
//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat
//@ no-prefer-dynamic
//@ needs-enzyme
// Just check it does not crash for now
// CHECK: ;
#![feature(autodiff)]
#![feature(core_intrinsics)]
#![feature(rustc_attrs)]
use std::autodiff::autodiff_reverse;
struct Foo {
a: f64,
}
trait MyTrait {
#[autodiff_reverse(df, Const, Active, Active)]
fn f(&self, x: f64) -> f64;
}
impl MyTrait for Foo {
fn f(&self, x: f64) -> f64 {
x.sin()
}
}
fn main() {
let foo = Foo { a: 3.0f64 };
dbg!(foo.df(2.0, 1.0));
dbg!(2.0_f64.cos());
}