mirror of
https://github.com/rust-lang/rust.git
synced 2026-04-27 18:57:42 +03:00
Port rustc_autodiff to the attribute parsers
This commit is contained in:
@@ -0,0 +1,117 @@
|
||||
use std::str::FromStr;
|
||||
|
||||
use rustc_ast::LitKind;
|
||||
use rustc_ast::expand::autodiff_attrs::{DiffActivity, DiffMode};
|
||||
use rustc_feature::{AttributeTemplate, template};
|
||||
use rustc_hir::attrs::{AttributeKind, RustcAutodiff};
|
||||
use rustc_hir::{MethodKind, Target};
|
||||
use rustc_span::{Symbol, sym};
|
||||
use thin_vec::ThinVec;
|
||||
|
||||
use crate::attributes::prelude::Allow;
|
||||
use crate::attributes::{AttributeOrder, OnDuplicate, SingleAttributeParser};
|
||||
use crate::context::{AcceptContext, Stage};
|
||||
use crate::parser::{ArgParser, MetaItemOrLitParser};
|
||||
use crate::target_checking::AllowedTargets;
|
||||
|
||||
pub(crate) struct RustcAutodiffParser;
|
||||
|
||||
impl<S: Stage> SingleAttributeParser<S> for RustcAutodiffParser {
|
||||
const PATH: &[Symbol] = &[sym::rustc_autodiff];
|
||||
const ATTRIBUTE_ORDER: AttributeOrder = AttributeOrder::KeepInnermost;
|
||||
const ON_DUPLICATE: OnDuplicate<S> = OnDuplicate::Error;
|
||||
const ALLOWED_TARGETS: AllowedTargets = AllowedTargets::AllowList(&[
|
||||
Allow(Target::Fn),
|
||||
Allow(Target::Method(MethodKind::Inherent)),
|
||||
Allow(Target::Method(MethodKind::Trait { body: true })),
|
||||
Allow(Target::Method(MethodKind::TraitImpl)),
|
||||
]);
|
||||
const TEMPLATE: AttributeTemplate = template!(
|
||||
List: &["MODE", "WIDTH", "INPUT_ACTIVITIES", "OUTPUT_ACTIVITY"],
|
||||
"https://doc.rust-lang.org/std/autodiff/index.html"
|
||||
);
|
||||
|
||||
fn convert(cx: &mut AcceptContext<'_, '_, S>, args: &ArgParser) -> Option<AttributeKind> {
|
||||
let list = match args {
|
||||
ArgParser::NoArgs => return Some(AttributeKind::RustcAutodiff(None)),
|
||||
ArgParser::List(list) => list,
|
||||
ArgParser::NameValue(_) => {
|
||||
cx.expected_list_or_no_args(cx.attr_span);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
let mut items = list.mixed().peekable();
|
||||
|
||||
// Parse name
|
||||
let Some(mode) = items.next() else {
|
||||
cx.expected_at_least_one_argument(list.span);
|
||||
return None;
|
||||
};
|
||||
let Some(mode) = mode.meta_item() else {
|
||||
cx.expected_identifier(mode.span());
|
||||
return None;
|
||||
};
|
||||
let Ok(()) = mode.args().no_args() else {
|
||||
cx.expected_identifier(mode.span());
|
||||
return None;
|
||||
};
|
||||
let Some(mode) = mode.path().word() else {
|
||||
cx.expected_identifier(mode.span());
|
||||
return None;
|
||||
};
|
||||
let Ok(mode) = DiffMode::from_str(mode.as_str()) else {
|
||||
cx.expected_specific_argument(mode.span, DiffMode::all_modes());
|
||||
return None;
|
||||
};
|
||||
|
||||
// Parse width
|
||||
let width = if let Some(width) = items.peek()
|
||||
&& let MetaItemOrLitParser::Lit(width) = width
|
||||
&& let LitKind::Int(width, _) = width.kind
|
||||
&& let Ok(width) = width.0.try_into()
|
||||
{
|
||||
_ = items.next();
|
||||
width
|
||||
} else {
|
||||
1
|
||||
};
|
||||
|
||||
// Parse activities
|
||||
let mut activities = ThinVec::new();
|
||||
for activity in items {
|
||||
let MetaItemOrLitParser::MetaItemParser(activity) = activity else {
|
||||
cx.expected_specific_argument(activity.span(), DiffActivity::all_activities());
|
||||
return None;
|
||||
};
|
||||
let Ok(()) = activity.args().no_args() else {
|
||||
cx.expected_specific_argument(activity.span(), DiffActivity::all_activities());
|
||||
return None;
|
||||
};
|
||||
let Some(activity) = activity.path().word() else {
|
||||
cx.expected_specific_argument(activity.span(), DiffActivity::all_activities());
|
||||
return None;
|
||||
};
|
||||
let Ok(activity) = DiffActivity::from_str(activity.as_str()) else {
|
||||
cx.expected_specific_argument(activity.span, DiffActivity::all_activities());
|
||||
return None;
|
||||
};
|
||||
|
||||
activities.push(activity);
|
||||
}
|
||||
let Some(ret_activity) = activities.pop() else {
|
||||
cx.expected_specific_argument(
|
||||
list.span.with_lo(list.span.hi()),
|
||||
DiffActivity::all_activities(),
|
||||
);
|
||||
return None;
|
||||
};
|
||||
|
||||
Some(AttributeKind::RustcAutodiff(Some(Box::new(RustcAutodiff {
|
||||
mode,
|
||||
width,
|
||||
input_activity: activities,
|
||||
ret_activity,
|
||||
}))))
|
||||
}
|
||||
}
|
||||
@@ -30,6 +30,7 @@
|
||||
mod prelude;
|
||||
|
||||
pub(crate) mod allow_unstable;
|
||||
pub(crate) mod autodiff;
|
||||
pub(crate) mod body;
|
||||
pub(crate) mod cfg;
|
||||
pub(crate) mod cfg_select;
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
use crate::AttributeParser;
|
||||
// Glob imports to avoid big, bitrotty import lists
|
||||
use crate::attributes::allow_unstable::*;
|
||||
use crate::attributes::autodiff::*;
|
||||
use crate::attributes::body::*;
|
||||
use crate::attributes::cfi_encoding::*;
|
||||
use crate::attributes::codegen_attrs::*;
|
||||
@@ -204,6 +205,7 @@ mod late {
|
||||
Single<ReexportTestHarnessMainParser>,
|
||||
Single<RustcAbiParser>,
|
||||
Single<RustcAllocatorZeroedVariantParser>,
|
||||
Single<RustcAutodiffParser>,
|
||||
Single<RustcBuiltinMacroParser>,
|
||||
Single<RustcDefPathParser>,
|
||||
Single<RustcDeprecatedSafe2024Parser>,
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
Align, BackendRepr, ExternAbi, Float, HasDataLayout, Primitive, Size, WrappingRange,
|
||||
};
|
||||
use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh};
|
||||
use rustc_codegen_ssa::codegen_attrs::autodiff_attrs;
|
||||
use rustc_codegen_ssa::common::{IntPredicate, TypeKind};
|
||||
use rustc_codegen_ssa::errors::{ExpectedPointerMutability, InvalidMonomorphization};
|
||||
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
|
||||
@@ -15,6 +14,7 @@
|
||||
use rustc_data_structures::assert_matches;
|
||||
use rustc_hir as hir;
|
||||
use rustc_hir::def_id::LOCAL_CRATE;
|
||||
use rustc_hir::find_attr;
|
||||
use rustc_middle::mir::BinOp;
|
||||
use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf};
|
||||
use rustc_middle::ty::offload_meta::OffloadMetadata;
|
||||
@@ -1367,7 +1367,9 @@ fn codegen_autodiff<'ll, 'tcx>(
|
||||
let val_arr = get_args_from_tuple(bx, args[2], fn_diff);
|
||||
let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE);
|
||||
|
||||
let Some(mut diff_attrs) = autodiff_attrs(tcx, fn_diff.def_id()) else {
|
||||
let Some(Some(mut diff_attrs)) =
|
||||
find_attr!(tcx, fn_diff.def_id(), RustcAutodiff(attr) => attr.clone())
|
||||
else {
|
||||
bug!("could not find autodiff attrs")
|
||||
};
|
||||
|
||||
@@ -1389,7 +1391,7 @@ fn codegen_autodiff<'ll, 'tcx>(
|
||||
&diff_symbol,
|
||||
llret_ty,
|
||||
&val_arr,
|
||||
diff_attrs.clone(),
|
||||
&diff_attrs,
|
||||
result,
|
||||
fnc_tree,
|
||||
);
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
use std::str::FromStr;
|
||||
|
||||
use rustc_abi::{Align, ExternAbi};
|
||||
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
|
||||
use rustc_ast::{LitKind, MetaItem, MetaItemInner};
|
||||
use rustc_hir::attrs::{
|
||||
AttributeKind, EiiImplResolution, InlineAttr, Linkage, RtsanSetting, UsedBy,
|
||||
};
|
||||
@@ -14,7 +10,6 @@
|
||||
};
|
||||
use rustc_middle::mir::mono::Visibility;
|
||||
use rustc_middle::query::Providers;
|
||||
use rustc_middle::span_bug;
|
||||
use rustc_middle::ty::{self as ty, TyCtxt};
|
||||
use rustc_session::lint;
|
||||
use rustc_session::parse::feature_err;
|
||||
@@ -614,116 +609,6 @@ fn inherited_align<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> Option<Align> {
|
||||
tcx.codegen_fn_attrs(tcx.trait_item_of(def_id)?).alignment
|
||||
}
|
||||
|
||||
/// We now check the #\[rustc_autodiff\] attributes which we generated from the #[autodiff(...)]
|
||||
/// macros. There are two forms. The pure one without args to mark primal functions (the functions
|
||||
/// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the
|
||||
/// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never
|
||||
/// panic, unless we introduced a bug when parsing the autodiff macro.
|
||||
//FIXME(jdonszelmann): put in the main loop. No need to have two..... :/ Let's do that when we make autodiff parsed.
|
||||
pub fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
|
||||
#[allow(deprecated)]
|
||||
let attrs = tcx.get_attrs(id, sym::rustc_autodiff);
|
||||
|
||||
let attrs = attrs.filter(|attr| attr.has_name(sym::rustc_autodiff)).collect::<Vec<_>>();
|
||||
|
||||
// check for exactly one autodiff attribute on placeholder functions.
|
||||
// There should only be one, since we generate a new placeholder per ad macro.
|
||||
let attr = match &attrs[..] {
|
||||
[] => return None,
|
||||
[attr] => attr,
|
||||
_ => {
|
||||
span_bug!(attrs[1].span(), "cg_ssa: rustc_autodiff should only exist once per source");
|
||||
}
|
||||
};
|
||||
|
||||
let list = attr.meta_item_list().unwrap_or_default();
|
||||
|
||||
// empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions
|
||||
if list.is_empty() {
|
||||
return Some(AutoDiffAttrs::source());
|
||||
}
|
||||
|
||||
let [mode, width_meta, input_activities @ .., ret_activity] = &list[..] else {
|
||||
span_bug!(attr.span(), "rustc_autodiff attribute must contain mode, width and activities");
|
||||
};
|
||||
let mode = if let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = mode {
|
||||
p1.segments.first().unwrap().ident
|
||||
} else {
|
||||
span_bug!(attr.span(), "rustc_autodiff attribute must contain mode");
|
||||
};
|
||||
|
||||
// parse mode
|
||||
let mode = match mode.as_str() {
|
||||
"Forward" => DiffMode::Forward,
|
||||
"Reverse" => DiffMode::Reverse,
|
||||
_ => {
|
||||
span_bug!(mode.span, "rustc_autodiff attribute contains invalid mode");
|
||||
}
|
||||
};
|
||||
|
||||
let width: u32 = match width_meta {
|
||||
MetaItemInner::MetaItem(MetaItem { path: p1, .. }) => {
|
||||
let w = p1.segments.first().unwrap().ident;
|
||||
match w.as_str().parse() {
|
||||
Ok(val) => val,
|
||||
Err(_) => {
|
||||
span_bug!(w.span, "rustc_autodiff width should fit u32");
|
||||
}
|
||||
}
|
||||
}
|
||||
MetaItemInner::Lit(lit) => {
|
||||
if let LitKind::Int(val, _) = lit.kind {
|
||||
match val.get().try_into() {
|
||||
Ok(val) => val,
|
||||
Err(_) => {
|
||||
span_bug!(lit.span, "rustc_autodiff width should fit u32");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
span_bug!(lit.span, "rustc_autodiff width should be an integer");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// First read the ret symbol from the attribute
|
||||
let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = ret_activity else {
|
||||
span_bug!(attr.span(), "rustc_autodiff attribute must contain the return activity");
|
||||
};
|
||||
let ret_symbol = p1.segments.first().unwrap().ident;
|
||||
|
||||
// Then parse it into an actual DiffActivity
|
||||
let Ok(ret_activity) = DiffActivity::from_str(ret_symbol.as_str()) else {
|
||||
span_bug!(ret_symbol.span, "invalid return activity");
|
||||
};
|
||||
|
||||
// Now parse all the intermediate (input) activities
|
||||
let mut arg_activities: Vec<DiffActivity> = vec![];
|
||||
for arg in input_activities {
|
||||
let arg_symbol = if let MetaItemInner::MetaItem(MetaItem { path: p2, .. }) = arg {
|
||||
match p2.segments.first() {
|
||||
Some(x) => x.ident,
|
||||
None => {
|
||||
span_bug!(
|
||||
arg.span(),
|
||||
"rustc_autodiff attribute must contain the input activity"
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
span_bug!(arg.span(), "rustc_autodiff attribute must contain the input activity");
|
||||
};
|
||||
|
||||
match DiffActivity::from_str(arg_symbol.as_str()) {
|
||||
Ok(arg_activity) => arg_activities.push(arg_activity),
|
||||
Err(_) => {
|
||||
span_bug!(arg_symbol.span, "invalid input activity");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(AutoDiffAttrs { mode, width, ret_activity, input_activity: arg_activities })
|
||||
}
|
||||
|
||||
pub(crate) fn provide(providers: &mut Providers) {
|
||||
*providers = Providers {
|
||||
codegen_fn_attrs,
|
||||
|
||||
@@ -1286,6 +1286,9 @@ pub enum AttributeKind {
|
||||
/// Represents `#[rustc_as_ptr]` (used by the `dangling_pointers_from_temporaries` lint).
|
||||
RustcAsPtr(Span),
|
||||
|
||||
/// Represents `#[rustc_autodiff]`.
|
||||
RustcAutodiff(Option<Box<RustcAutodiff>>),
|
||||
|
||||
/// Represents `#[rustc_default_body_unstable]`.
|
||||
RustcBodyStability {
|
||||
stability: DefaultBodyStability,
|
||||
|
||||
@@ -102,6 +102,7 @@ pub fn encode_cross_crate(&self) -> EncodeCrossCrate {
|
||||
RustcAllowConstFnUnstable(..) => No,
|
||||
RustcAllowIncoherentImpl(..) => No,
|
||||
RustcAsPtr(..) => Yes,
|
||||
RustcAutodiff(..) => Yes,
|
||||
RustcBodyStability { .. } => No,
|
||||
RustcBuiltinMacro { .. } => Yes,
|
||||
RustcCaptureAnalysis => No,
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
use rustc_ast::ast::{Path, join_path_idents};
|
||||
use rustc_ast::attr::data_structures::CfgEntry;
|
||||
use rustc_ast::attr::version::RustcVersion;
|
||||
use rustc_ast::expand::autodiff_attrs::{DiffActivity, DiffMode};
|
||||
use rustc_ast::token::{CommentKind, DocFragmentKind};
|
||||
use rustc_ast::{AttrId, AttrStyle, IntTy, UintTy};
|
||||
use rustc_ast_pretty::pp::Printer;
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
use rustc_middle::query::Providers;
|
||||
use rustc_middle::ty::TyCtxt;
|
||||
use rustc_session::config::{InliningThreshold, OptLevel};
|
||||
use rustc_span::sym;
|
||||
|
||||
use crate::{inline, pass_manager as pm};
|
||||
|
||||
@@ -37,11 +36,7 @@ fn cross_crate_inlinable(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool {
|
||||
}
|
||||
|
||||
// FIXME(autodiff): replace this as per discussion in https://github.com/rust-lang/rust/pull/149033#discussion_r2535465880
|
||||
#[allow(deprecated)]
|
||||
if tcx.has_attr(def_id, sym::autodiff_forward)
|
||||
|| tcx.has_attr(def_id, sym::autodiff_reverse)
|
||||
|| tcx.has_attr(def_id, sym::rustc_autodiff)
|
||||
{
|
||||
if find_attr!(tcx, def_id, RustcAutodiff(..)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -53,7 +53,6 @@
|
||||
use rustc_trait_selection::error_reporting::InferCtxtErrorExt;
|
||||
use rustc_trait_selection::infer::{TyCtxtInferExt, ValuePairs};
|
||||
use rustc_trait_selection::traits::ObligationCtxt;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::errors;
|
||||
|
||||
@@ -299,6 +298,7 @@ fn check_attributes(
|
||||
| AttributeKind::RustcAllocatorZeroed
|
||||
| AttributeKind::RustcAllocatorZeroedVariant { .. }
|
||||
| AttributeKind::RustcAsPtr(..)
|
||||
| AttributeKind::RustcAutodiff(..)
|
||||
| AttributeKind::RustcBodyStability { .. }
|
||||
| AttributeKind::RustcBuiltinMacro { .. }
|
||||
| AttributeKind::RustcCaptureAnalysis
|
||||
@@ -390,9 +390,6 @@ fn check_attributes(
|
||||
Attribute::Unparsed(attr_item) => {
|
||||
style = Some(attr_item.style);
|
||||
match attr.path().as_slice() {
|
||||
[sym::autodiff_forward, ..] | [sym::autodiff_reverse, ..] => {
|
||||
self.check_autodiff(hir_id, attr, span, target)
|
||||
}
|
||||
[
|
||||
// ok
|
||||
sym::allow
|
||||
@@ -402,8 +399,6 @@ fn check_attributes(
|
||||
| sym::forbid
|
||||
// internal
|
||||
| sym::rustc_on_unimplemented
|
||||
| sym::rustc_layout
|
||||
| sym::rustc_autodiff
|
||||
// crate-level attrs, are checked below
|
||||
| sym::feature,
|
||||
..
|
||||
@@ -1863,18 +1858,6 @@ fn check_mix_no_mangle_export(&self, hir_id: HirId, attrs: &[Attribute]) {
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks if `#[autodiff]` is applied to an item other than a function item.
|
||||
fn check_autodiff(&self, _hir_id: HirId, _attr: &Attribute, span: Span, target: Target) {
|
||||
debug!("check_autodiff");
|
||||
match target {
|
||||
Target::Fn => {}
|
||||
_ => {
|
||||
self.dcx().emit_err(errors::AutoDiffAttr { attr_span: span });
|
||||
self.abort.set(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn check_loop_match(&self, hir_id: HirId, attr_span: Span, target: Target) {
|
||||
let node_span = self.tcx.hir_span(hir_id);
|
||||
|
||||
|
||||
@@ -19,14 +19,6 @@
|
||||
#[diag("`#[diagnostic::do_not_recommend]` can only be placed on trait implementations")]
|
||||
pub(crate) struct IncorrectDoNotRecommendLocation;
|
||||
|
||||
#[derive(Diagnostic)]
|
||||
#[diag("`#[autodiff]` should be applied to a function")]
|
||||
pub(crate) struct AutoDiffAttr {
|
||||
#[primary_span]
|
||||
#[label("not a function")]
|
||||
pub attr_span: Span,
|
||||
}
|
||||
|
||||
#[derive(Diagnostic)]
|
||||
#[diag("`#[loop_match]` should be applied to a loop")]
|
||||
pub(crate) struct LoopMatchAttr {
|
||||
|
||||
@@ -157,6 +157,8 @@
|
||||
Abi,
|
||||
AcqRel,
|
||||
Acquire,
|
||||
Active,
|
||||
ActiveOnly,
|
||||
Alignment,
|
||||
Arc,
|
||||
ArcWeak,
|
||||
@@ -213,6 +215,12 @@
|
||||
Deref,
|
||||
DispatchFromDyn,
|
||||
Display,
|
||||
Dual,
|
||||
DualOnly,
|
||||
Dualv,
|
||||
DualvOnly,
|
||||
Duplicated,
|
||||
DuplicatedOnly,
|
||||
DynTrait,
|
||||
Enum,
|
||||
Eq,
|
||||
@@ -310,6 +318,7 @@
|
||||
Slice,
|
||||
SliceIndex,
|
||||
Some,
|
||||
Source,
|
||||
SpanCtxt,
|
||||
Str,
|
||||
String,
|
||||
|
||||
Reference in New Issue
Block a user