mirror of
https://github.com/rust-lang/rust.git
synced 2026-05-31 13:40:15 +03:00
c29fb2e57e
TypeTree support in autodiff
# TypeTrees for Autodiff
## What are TypeTrees?
TypeTrees are memory layout descriptors for Enzyme. They tell Enzyme exactly how types are structured in memory so that it can compute derivatives efficiently.
## Structure
```rust
TypeTree(Vec<Type>)
Type {
offset: isize, // byte offset (-1 = everywhere)
size: usize, // size in bytes
kind: Kind, // Float, Integer, Pointer, etc.
child: TypeTree // nested structure
}
```
## Example: `fn compute(x: &f32, data: &[f32]) -> f32`
**Input 0: `x: &f32`**
```rust
TypeTree(vec![Type {
offset: -1, size: 8, kind: Pointer,
child: TypeTree(vec![Type {
offset: -1, size: 4, kind: Float,
child: TypeTree::new()
}])
}])
```
**Input 1: `data: &[f32]`**
```rust
TypeTree(vec![Type {
offset: -1, size: 8, kind: Pointer,
child: TypeTree(vec![Type {
offset: -1, size: 4, kind: Float, // -1 = all elements
child: TypeTree::new()
}])
}])
```
**Output: `f32`**
```rust
TypeTree(vec![Type {
offset: -1, size: 4, kind: Float,
child: TypeTree::new()
}])
```
## Why Needed?
- Enzyme can't deduce complex type layouts from LLVM IR
- Prevents slow memory pattern analysis
- Enables correct derivative computation for nested structures
- Tells Enzyme which bytes are differentiable vs metadata
## What Enzyme Does With This Information:
Without TypeTrees (current state):
```llvm
; Enzyme sees generic LLVM IR:
define float @distance(ptr* %p1, ptr* %p2) {
; Has to guess what these pointers point to
; Slow analysis of all memory operations
; May miss optimization opportunities
}
```
With TypeTrees (our implementation):
```llvm
define "enzyme_type"="{[]:Float@float}" float @distance(
ptr "enzyme_type"="{[]:Pointer}" %p1,
ptr "enzyme_type"="{[]:Pointer}" %p2
) {
; Enzyme knows exact type layout
; Can generate efficient derivative code directly
}
```
# TypeTrees - Offset and -1 Explained
## Type Structure
```rust
Type {
offset: isize, // WHERE this type starts
size: usize, // HOW BIG this type is
kind: Kind, // WHAT KIND of data (Float, Integer, Pointer)
child: TypeTree // WHAT'S INSIDE (for pointers/containers)
}
```
## Offset Values
### Regular Offset (0, 4, 8, etc.)
**Specific byte position within a structure**
```rust
struct Point {
x: f32, // offset 0, size 4
y: f32, // offset 4, size 4
id: i32, // offset 8, size 4
}
```
TypeTree for `&Point` (internal representation):
```rust
TypeTree(vec![
Type { offset: 0, size: 4, kind: Float }, // x at byte 0
Type { offset: 4, size: 4, kind: Float }, // y at byte 4
Type { offset: 8, size: 4, kind: Integer } // id at byte 8
])
```
Generates LLVM:
```llvm
"enzyme_type"="{[]:Float@float}"
```
### Offset -1 (Special: "Everywhere")
**Means "this pattern repeats for ALL elements"**
#### Example 1: Array `[f32; 100]`
```rust
TypeTree(vec![Type {
offset: -1, // ALL positions
size: 4, // each f32 is 4 bytes
kind: Float, // every element is float
}])
```
This single entry replaces 100 separate Types with offsets `0, 4, 8, 12, ..., 396`.
#### Example 2: Slice `&[i32]`
```rust
// Pointer to slice data
TypeTree(vec![Type {
offset: -1, size: 8, kind: Pointer,
child: TypeTree(vec![Type {
offset: -1, // ALL slice elements
size: 4, // each i32 is 4 bytes
kind: Integer
}])
}])
```
#### Example 3: Mixed Structure
```rust
struct Container {
header: i64, // offset 0
data: [f32; 1000], // offset 8, but elements use -1
}
```
```rust
TypeTree(vec![
Type { offset: 0, size: 8, kind: Integer }, // header
Type { offset: 8, size: 4000, kind: Pointer,
child: TypeTree(vec![Type {
offset: -1, size: 4, kind: Float // ALL array elements
}])
}
])
```
194 lines
4.5 KiB
Rust
194 lines
4.5 KiB
Rust
use std::path::Path;
|
|
use std::{fmt, io};
|
|
|
|
use rustc_errors::codes::*;
|
|
use rustc_errors::{DiagArgName, DiagArgValue, DiagMessage};
|
|
use rustc_macros::{Diagnostic, Subdiagnostic};
|
|
use rustc_span::{Span, Symbol};
|
|
|
|
use crate::ty::{Instance, Ty};
|
|
|
|
#[derive(Diagnostic)]
|
|
#[diag(middle_drop_check_overflow, code = E0320)]
|
|
#[note]
|
|
pub(crate) struct DropCheckOverflow<'tcx> {
|
|
#[primary_span]
|
|
pub span: Span,
|
|
pub ty: Ty<'tcx>,
|
|
pub overflow_ty: Ty<'tcx>,
|
|
}
|
|
|
|
#[derive(Diagnostic)]
|
|
#[diag(middle_failed_writing_file)]
|
|
pub(crate) struct FailedWritingFile<'a> {
|
|
pub path: &'a Path,
|
|
pub error: io::Error,
|
|
}
|
|
|
|
#[derive(Diagnostic)]
|
|
#[diag(middle_opaque_hidden_type_mismatch)]
|
|
pub(crate) struct OpaqueHiddenTypeMismatch<'tcx> {
|
|
pub self_ty: Ty<'tcx>,
|
|
pub other_ty: Ty<'tcx>,
|
|
#[primary_span]
|
|
#[label]
|
|
pub other_span: Span,
|
|
#[subdiagnostic]
|
|
pub sub: TypeMismatchReason,
|
|
}
|
|
|
|
#[derive(Diagnostic)]
|
|
#[diag(middle_unsupported_union)]
|
|
pub struct UnsupportedUnion {
|
|
pub ty_name: String,
|
|
}
|
|
|
|
// FIXME(autodiff): I should get used somewhere
|
|
#[derive(Diagnostic)]
|
|
#[diag(middle_autodiff_unsafe_inner_const_ref)]
|
|
pub struct AutodiffUnsafeInnerConstRef<'tcx> {
|
|
#[primary_span]
|
|
pub span: Span,
|
|
pub ty: Ty<'tcx>,
|
|
}
|
|
|
|
#[derive(Subdiagnostic)]
|
|
pub enum TypeMismatchReason {
|
|
#[label(middle_conflict_types)]
|
|
ConflictType {
|
|
#[primary_span]
|
|
span: Span,
|
|
},
|
|
#[note(middle_previous_use_here)]
|
|
PreviousUse {
|
|
#[primary_span]
|
|
span: Span,
|
|
},
|
|
}
|
|
|
|
#[derive(Diagnostic)]
|
|
#[diag(middle_recursion_limit_reached)]
|
|
#[help]
|
|
pub(crate) struct RecursionLimitReached<'tcx> {
|
|
#[primary_span]
|
|
pub span: Span,
|
|
pub ty: Ty<'tcx>,
|
|
pub suggested_limit: rustc_hir::limit::Limit,
|
|
}
|
|
|
|
#[derive(Diagnostic)]
|
|
#[diag(middle_const_eval_non_int)]
|
|
pub(crate) struct ConstEvalNonIntError {
|
|
#[primary_span]
|
|
pub span: Span,
|
|
}
|
|
|
|
#[derive(Diagnostic)]
|
|
#[diag(middle_strict_coherence_needs_negative_coherence)]
|
|
pub(crate) struct StrictCoherenceNeedsNegativeCoherence {
|
|
#[primary_span]
|
|
pub span: Span,
|
|
#[label]
|
|
pub attr_span: Option<Span>,
|
|
}
|
|
|
|
#[derive(Diagnostic)]
|
|
#[diag(middle_requires_lang_item)]
|
|
pub(crate) struct RequiresLangItem {
|
|
#[primary_span]
|
|
pub span: Span,
|
|
pub name: Symbol,
|
|
}
|
|
|
|
#[derive(Diagnostic)]
|
|
#[diag(middle_const_not_used_in_type_alias)]
|
|
pub(super) struct ConstNotUsedTraitAlias {
|
|
pub ct: String,
|
|
#[primary_span]
|
|
pub span: Span,
|
|
}
|
|
|
|
pub struct CustomSubdiagnostic<'a> {
|
|
pub msg: fn() -> DiagMessage,
|
|
pub add_args: Box<dyn FnOnce(&mut dyn FnMut(DiagArgName, DiagArgValue)) + 'a>,
|
|
}
|
|
|
|
impl<'a> CustomSubdiagnostic<'a> {
|
|
pub fn label(x: fn() -> DiagMessage) -> Self {
|
|
Self::label_and_then(x, |_| {})
|
|
}
|
|
pub fn label_and_then<F: FnOnce(&mut dyn FnMut(DiagArgName, DiagArgValue)) + 'a>(
|
|
msg: fn() -> DiagMessage,
|
|
f: F,
|
|
) -> Self {
|
|
Self { msg, add_args: Box::new(move |x| f(x)) }
|
|
}
|
|
}
|
|
|
|
impl fmt::Debug for CustomSubdiagnostic<'_> {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
f.debug_struct("CustomSubdiagnostic").finish_non_exhaustive()
|
|
}
|
|
}
|
|
|
|
#[derive(Diagnostic)]
|
|
pub enum LayoutError<'tcx> {
|
|
#[diag(middle_layout_unknown)]
|
|
Unknown { ty: Ty<'tcx> },
|
|
|
|
#[diag(middle_layout_too_generic)]
|
|
TooGeneric { ty: Ty<'tcx> },
|
|
|
|
#[diag(middle_layout_size_overflow)]
|
|
Overflow { ty: Ty<'tcx> },
|
|
|
|
#[diag(middle_layout_simd_too_many)]
|
|
SimdTooManyLanes { ty: Ty<'tcx>, max_lanes: u64 },
|
|
|
|
#[diag(middle_layout_simd_zero_length)]
|
|
SimdZeroLength { ty: Ty<'tcx> },
|
|
|
|
#[diag(middle_layout_normalization_failure)]
|
|
NormalizationFailure { ty: Ty<'tcx>, failure_ty: String },
|
|
|
|
#[diag(middle_layout_cycle)]
|
|
Cycle,
|
|
|
|
#[diag(middle_layout_references_error)]
|
|
ReferencesError,
|
|
}
|
|
|
|
#[derive(Diagnostic)]
|
|
#[diag(middle_erroneous_constant)]
|
|
pub(crate) struct ErroneousConstant {
|
|
#[primary_span]
|
|
pub span: Span,
|
|
}
|
|
|
|
#[derive(Diagnostic)]
|
|
#[diag(middle_type_length_limit)]
|
|
#[help(middle_consider_type_length_limit)]
|
|
pub(crate) struct TypeLengthLimit<'tcx> {
|
|
#[primary_span]
|
|
pub span: Span,
|
|
pub instance: Instance<'tcx>,
|
|
pub type_length: usize,
|
|
}
|
|
|
|
#[derive(Diagnostic)]
|
|
#[diag(middle_max_num_nodes_in_valtree)]
|
|
pub(crate) struct MaxNumNodesInValtree {
|
|
#[primary_span]
|
|
pub span: Span,
|
|
pub global_const_id: String,
|
|
}
|
|
|
|
#[derive(Diagnostic)]
|
|
#[diag(middle_invalid_const_in_valtree)]
|
|
#[note]
|
|
pub(crate) struct InvalidConstInValtree {
|
|
#[primary_span]
|
|
pub span: Span,
|
|
pub global_const_id: String,
|
|
}
|