mirror of
https://github.com/rust-lang/rust.git
synced 2026-04-27 18:57:42 +03:00
Allow applying autodiff macros to trait functions.
It will use enzyme to generate a default derivative implementation, which can be overwritten by the user.
This commit is contained in:
@@ -24,6 +24,7 @@ impl<S: Stage> SingleAttributeParser<S> for RustcAutodiffParser {
|
||||
Allow(Target::Fn),
|
||||
Allow(Target::Method(MethodKind::Inherent)),
|
||||
Allow(Target::Method(MethodKind::Trait { body: true })),
|
||||
Allow(Target::Method(MethodKind::Trait { body: false })),
|
||||
Allow(Target::Method(MethodKind::TraitImpl)),
|
||||
]);
|
||||
const TEMPLATE: AttributeTemplate = template!(
|
||||
|
||||
@@ -224,16 +224,18 @@ pub(crate) fn expand_with_mode(
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
Annotatable::AssocItem(assoc_item, Impl { of_trait: _ }) => match &assoc_item.kind {
|
||||
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => Some((
|
||||
assoc_item.vis.clone(),
|
||||
sig.clone(),
|
||||
ident.clone(),
|
||||
generics.clone(),
|
||||
true,
|
||||
)),
|
||||
_ => None,
|
||||
},
|
||||
Annotatable::AssocItem(assoc_item, _ctxt @ (Impl { of_trait: _ } | Trait)) => {
|
||||
match &assoc_item.kind {
|
||||
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => Some((
|
||||
assoc_item.vis.clone(),
|
||||
sig.clone(),
|
||||
ident.clone(),
|
||||
generics.clone(),
|
||||
true,
|
||||
)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
}) else {
|
||||
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
|
||||
@@ -393,14 +395,14 @@ fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
|
||||
}
|
||||
Annotatable::Item(iitem.clone())
|
||||
}
|
||||
Annotatable::AssocItem(ref mut assoc_item, i @ Impl { .. }) => {
|
||||
Annotatable::AssocItem(ref mut assoc_item, ctxt @ (Impl { .. } | Trait)) => {
|
||||
if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
|
||||
assoc_item.attrs.push(attr);
|
||||
}
|
||||
if assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
|
||||
has_inline_never = true;
|
||||
}
|
||||
Annotatable::AssocItem(assoc_item.clone(), i)
|
||||
Annotatable::AssocItem(assoc_item.clone(), ctxt)
|
||||
}
|
||||
Annotatable::Stmt(ref mut stmt) => {
|
||||
match stmt.kind {
|
||||
@@ -441,7 +443,7 @@ fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
|
||||
}
|
||||
|
||||
let d_annotatable = match &item {
|
||||
Annotatable::AssocItem(_, _) => {
|
||||
Annotatable::AssocItem(_, ctxt) => {
|
||||
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(d_fn);
|
||||
let d_fn = Box::new(ast::AssocItem {
|
||||
attrs: d_attrs,
|
||||
@@ -451,7 +453,7 @@ fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
|
||||
kind: assoc_item,
|
||||
tokens: None,
|
||||
});
|
||||
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
|
||||
Annotatable::AssocItem(d_fn, *ctxt)
|
||||
}
|
||||
Annotatable::Item(_) => {
|
||||
let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn));
|
||||
|
||||
@@ -149,14 +149,14 @@ pub fn expect_item(self) -> Box<ast::Item> {
|
||||
pub fn expect_trait_item(self) -> Box<ast::AssocItem> {
|
||||
match self {
|
||||
Annotatable::AssocItem(i, AssocCtxt::Trait) => i,
|
||||
_ => panic!("expected Item"),
|
||||
_ => panic!("expected trait item"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn expect_impl_item(self) -> Box<ast::AssocItem> {
|
||||
match self {
|
||||
Annotatable::AssocItem(i, AssocCtxt::Impl { .. }) => i,
|
||||
_ => panic!("expected Item"),
|
||||
_ => panic!("expected impl item"),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat
|
||||
//@ no-prefer-dynamic
|
||||
//@ needs-enzyme
|
||||
|
||||
// Just check it does not crash for now
|
||||
// CHECK: ;
|
||||
#![feature(autodiff)]
|
||||
#![feature(core_intrinsics)]
|
||||
#![feature(rustc_attrs)]
|
||||
|
||||
use std::autodiff::autodiff_reverse;
|
||||
|
||||
struct Foo {
|
||||
a: f64,
|
||||
}
|
||||
|
||||
trait MyTrait {
|
||||
#[rustc_autodiff]
|
||||
fn f(&self, x: f64) -> f64;
|
||||
#[rustc_autodiff(Reverse, 1, Const, Active, Active)]
|
||||
fn df(&self, x: f64, seed: f64) -> (f64, f64) {
|
||||
std::hint::black_box(seed);
|
||||
std::hint::black_box(x);
|
||||
::std::intrinsics::autodiff(
|
||||
Self::f as for<'a> fn(&'a Self, _: f64) -> f64,
|
||||
Self::df,
|
||||
(self, x, seed),
|
||||
)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
impl MyTrait for Foo {
|
||||
fn f(&self, x: f64) -> f64 {
|
||||
x.sin()
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let foo = Foo { a: 3.0f64 };
|
||||
dbg!(foo.df(2.0, 1.0));
|
||||
dbg!(2.0_f64.cos());
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat
|
||||
//@ no-prefer-dynamic
|
||||
//@ needs-enzyme
|
||||
|
||||
// Just check it does not crash for now
|
||||
// CHECK: ;
|
||||
#![feature(autodiff)]
|
||||
#![feature(core_intrinsics)]
|
||||
#![feature(rustc_attrs)]
|
||||
|
||||
use std::autodiff::autodiff_reverse;
|
||||
|
||||
struct Foo {
|
||||
a: f64,
|
||||
}
|
||||
|
||||
trait MyTrait {
|
||||
#[autodiff_reverse(df, Const, Active, Active)]
|
||||
fn f(&self, x: f64) -> f64;
|
||||
}
|
||||
|
||||
impl MyTrait for Foo {
|
||||
fn f(&self, x: f64) -> f64 {
|
||||
x.sin()
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let foo = Foo { a: 3.0f64 };
|
||||
dbg!(foo.df(2.0, 1.0));
|
||||
dbg!(2.0_f64.cos());
|
||||
}
|
||||
Reference in New Issue
Block a user