Files
rust/compiler/rustc_middle/src/error.rs
T
Matthias Krüger c29fb2e57e Rollup merge of #144197 - KMJ-007:type-tree, r=ZuseZ4
TypeTree support in autodiff

# TypeTrees for Autodiff

## What are TypeTrees?
TypeTrees are memory layout descriptors for Enzyme. They tell Enzyme exactly how types are structured in memory so that it can compute derivatives efficiently.

## Structure
```rust
TypeTree(Vec<Type>)

Type {
    offset: isize,  // byte offset (-1 = everywhere)
    size: usize,    // size in bytes
    kind: Kind,     // Float, Integer, Pointer, etc.
    child: TypeTree // nested structure
}
```

## Example: `fn compute(x: &f32, data: &[f32]) -> f32`

**Input 0: `x: &f32`**
```rust
TypeTree(vec![Type {
    offset: -1, size: 8, kind: Pointer,
    child: TypeTree(vec![Type {
        offset: -1, size: 4, kind: Float,
        child: TypeTree::new()
    }])
}])
```

**Input 1: `data: &[f32]`**
```rust
TypeTree(vec![Type {
    offset: -1, size: 8, kind: Pointer,
    child: TypeTree(vec![Type {
        offset: -1, size: 4, kind: Float,  // -1 = all elements
        child: TypeTree::new()
    }])
}])
```

**Output: `f32`**
```rust
TypeTree(vec![Type {
    offset: -1, size: 4, kind: Float,
    child: TypeTree::new()
}])
```

## Why Needed?
- Enzyme can't deduce complex type layouts from LLVM IR
- Prevents slow memory pattern analysis
- Enables correct derivative computation for nested structures
- Tells Enzyme which bytes are differentiable vs metadata

## What Enzyme Does With This Information:

Without TypeTrees (current state):
```llvm
; Enzyme sees generic LLVM IR:
define float @distance(ptr* %p1, ptr* %p2) {
; Has to guess what these pointers point to
; Slow analysis of all memory operations
; May miss optimization opportunities
}
```

With TypeTrees (our implementation):
```llvm
define "enzyme_type"="{[]:Float@float}" float @distance(
    ptr "enzyme_type"="{[]:Pointer}" %p1,
    ptr "enzyme_type"="{[]:Pointer}" %p2
) {
; Enzyme knows exact type layout
; Can generate efficient derivative code directly
}
```

# TypeTrees - Offset and -1 Explained

## Type Structure

```rust
Type {
    offset: isize, // WHERE this type starts
    size: usize,   // HOW BIG this type is
    kind: Kind,    // WHAT KIND of data (Float, Integer, Pointer)
    child: TypeTree // WHAT'S INSIDE (for pointers/containers)
}
```

## Offset Values

### Regular Offset (0, 4, 8, etc.)
**Specific byte position within a structure**

```rust
struct Point {
    x: f32, // offset 0, size 4
    y: f32, // offset 4, size 4
    id: i32, // offset 8, size 4
}
```

TypeTree for `&Point` (internal representation):
```rust
TypeTree(vec![
    Type { offset: 0, size: 4, kind: Float },   // x at byte 0
    Type { offset: 4, size: 4, kind: Float },   // y at byte 4
    Type { offset: 8, size: 4, kind: Integer }  // id at byte 8
])
```

Generates LLVM:
```llvm
"enzyme_type"="{[]:Float@float}"
```

### Offset -1 (Special: "Everywhere")
**Means "this pattern repeats for ALL elements"**

#### Example 1: Array `[f32; 100]`
```rust
TypeTree(vec![Type {
    offset: -1, // ALL positions
    size: 4,    // each f32 is 4 bytes
    kind: Float, // every element is float
}])
```

This single entry replaces listing 100 separate `Type`s with offsets `0, 4, 8, 12, ..., 396`.

#### Example 2: Slice `&[i32]`
```rust
// Pointer to slice data
TypeTree(vec![Type {
    offset: -1, size: 8, kind: Pointer,
    child: TypeTree(vec![Type {
        offset: -1, // ALL slice elements
        size: 4,    // each i32 is 4 bytes
        kind: Integer
    }])
}])
```

#### Example 3: Mixed Structure
```rust
struct Container {
    header: i64,        // offset 0
    data: [f32; 1000],  // offset 8, but elements use -1
}
```

```rust
TypeTree(vec![
    Type { offset: 0, size: 8, kind: Integer }, // header
    Type { offset: 8, size: 4000, kind: Pointer,
        child: TypeTree(vec![Type {
            offset: -1, size: 4, kind: Float // ALL array elements
        }])
    }
])
```
2025-09-28 18:13:11 +02:00

194 lines
4.5 KiB
Rust

use std::path::Path;
use std::{fmt, io};
use rustc_errors::codes::*;
use rustc_errors::{DiagArgName, DiagArgValue, DiagMessage};
use rustc_macros::{Diagnostic, Subdiagnostic};
use rustc_span::{Span, Symbol};
use crate::ty::{Instance, Ty};
/// Diagnostic for the `middle_drop_check_overflow` message, carrying error code `E0320`.
///
/// A `#[note]` is attached to the diagnostic in addition to the main message.
#[derive(Diagnostic)]
#[diag(middle_drop_check_overflow, code = E0320)]
#[note]
pub(crate) struct DropCheckOverflow<'tcx> {
/// Primary location the diagnostic points at.
#[primary_span]
pub span: Span,
// NOTE(review): `ty` and `overflow_ty` carry no attribute; presumably they are only
// interpolated into the message text -- confirm against the message definition.
pub ty: Ty<'tcx>,
pub overflow_ty: Ty<'tcx>,
}
/// Diagnostic for the `middle_failed_writing_file` message.
///
/// It has no span field; it only carries a borrowed path and an I/O error.
#[derive(Diagnostic)]
#[diag(middle_failed_writing_file)]
pub(crate) struct FailedWritingFile<'a> {
// NOTE(review): presumably the file that could not be written -- confirm at the use site.
pub path: &'a Path,
/// The I/O error associated with this diagnostic.
pub error: io::Error,
}
/// Diagnostic for the `middle_opaque_hidden_type_mismatch` message.
///
/// Carries two types, a labelled primary span, and a nested [`TypeMismatchReason`] subdiagnostic.
#[derive(Diagnostic)]
#[diag(middle_opaque_hidden_type_mismatch)]
pub(crate) struct OpaqueHiddenTypeMismatch<'tcx> {
// NOTE(review): `self_ty`/`other_ty` carry no attribute; presumably they are the two
// disagreeing types interpolated into the message -- confirm against the message definition.
pub self_ty: Ty<'tcx>,
pub other_ty: Ty<'tcx>,
/// Serves as both the primary span and the location of the diagnostic's label.
#[primary_span]
#[label]
pub other_span: Span,
/// Additional subdiagnostic attached to this one.
#[subdiagnostic]
pub sub: TypeMismatchReason,
}
/// Diagnostic for the `middle_unsupported_union` message.
///
/// Unlike most diagnostics in this file it is fully `pub` and has no span field.
#[derive(Diagnostic)]
#[diag(middle_unsupported_union)]
pub struct UnsupportedUnion {
// NOTE(review): presumably the printed name of the offending union type -- confirm at the
// use site.
pub ty_name: String,
}
// FIXME(autodiff): I should get used somewhere
/// Diagnostic for the `middle_autodiff_unsafe_inner_const_ref` message.
///
/// Per the FIXME above, nothing is known to emit this diagnostic yet.
#[derive(Diagnostic)]
#[diag(middle_autodiff_unsafe_inner_const_ref)]
pub struct AutodiffUnsafeInnerConstRef<'tcx> {
/// Primary location the diagnostic points at.
#[primary_span]
pub span: Span,
// NOTE(review): no attribute on `ty`; presumably interpolated into the message text.
pub ty: Ty<'tcx>,
}
/// Subdiagnostic attached to [`OpaqueHiddenTypeMismatch`] through its `sub` field.
///
/// Each variant carries a single span and selects a different message and attachment kind.
#[derive(Subdiagnostic)]
pub enum TypeMismatchReason {
/// Adds a label using the `middle_conflict_types` message at `span`.
#[label(middle_conflict_types)]
ConflictType {
#[primary_span]
span: Span,
},
/// Adds a note using the `middle_previous_use_here` message at `span`.
#[note(middle_previous_use_here)]
PreviousUse {
#[primary_span]
span: Span,
},
}
/// Diagnostic for the `middle_recursion_limit_reached` message.
///
/// A `#[help]` is attached to the diagnostic in addition to the main message.
#[derive(Diagnostic)]
#[diag(middle_recursion_limit_reached)]
#[help]
pub(crate) struct RecursionLimitReached<'tcx> {
/// Primary location the diagnostic points at.
#[primary_span]
pub span: Span,
// NOTE(review): `ty` and `suggested_limit` carry no attribute; presumably they are
// interpolated into the message/help text -- confirm against the message definition.
pub ty: Ty<'tcx>,
pub suggested_limit: rustc_hir::limit::Limit,
}
/// Diagnostic for the `middle_const_eval_non_int` message; carries only a primary span.
#[derive(Diagnostic)]
#[diag(middle_const_eval_non_int)]
pub(crate) struct ConstEvalNonIntError {
/// Primary location the diagnostic points at.
#[primary_span]
pub span: Span,
}
/// Diagnostic for the `middle_strict_coherence_needs_negative_coherence` message.
#[derive(Diagnostic)]
#[diag(middle_strict_coherence_needs_negative_coherence)]
pub(crate) struct StrictCoherenceNeedsNegativeCoherence {
/// Primary location the diagnostic points at.
#[primary_span]
pub span: Span,
// NOTE(review): optional label location; presumably the label is only added when this is
// `Some`, and the name suggests it points at an attribute -- confirm at the use site.
#[label]
pub attr_span: Option<Span>,
}
/// Diagnostic for the `middle_requires_lang_item` message.
#[derive(Diagnostic)]
#[diag(middle_requires_lang_item)]
pub(crate) struct RequiresLangItem {
/// Primary location the diagnostic points at.
#[primary_span]
pub span: Span,
// NOTE(review): presumably the name of the required lang item, interpolated into the
// message -- confirm against the message definition.
pub name: Symbol,
}
/// Diagnostic for the `middle_const_not_used_in_type_alias` message.
///
/// Visible only to the parent module (`pub(super)`).
// NOTE(review): the struct name says "TraitAlias" while the message slug says "type_alias";
// confirm which of the two is intended.
#[derive(Diagnostic)]
#[diag(middle_const_not_used_in_type_alias)]
pub(super) struct ConstNotUsedTraitAlias {
// NOTE(review): presumably the printed form of the unused const -- confirm at the use site.
pub ct: String,
/// Primary location the diagnostic points at.
#[primary_span]
pub span: Span,
}
/// A hand-assembled subdiagnostic: a message constructor paired with a one-shot callback that
/// supplies the arguments for that message.
pub struct CustomSubdiagnostic<'a> {
    /// Produces the message when called.
    pub msg: fn() -> DiagMessage,
    /// Consumed once; it receives a sink and may push `(name, value)` argument pairs into it.
    pub add_args: Box<dyn FnOnce(&mut dyn FnMut(DiagArgName, DiagArgValue)) + 'a>,
}

impl<'a> CustomSubdiagnostic<'a> {
    /// Builds a value for message `x` whose argument callback pushes nothing.
    pub fn label(x: fn() -> DiagMessage) -> Self {
        Self::label_and_then(x, |_sink| {})
    }

    /// Builds a value for message `msg` whose arguments are supplied by `f`.
    ///
    /// `f` is only stored here (boxed); it runs when the caller later invokes `add_args`.
    pub fn label_and_then<F: FnOnce(&mut dyn FnMut(DiagArgName, DiagArgValue)) + 'a>(
        msg: fn() -> DiagMessage,
        f: F,
    ) -> Self {
        let add_args: Box<dyn FnOnce(&mut dyn FnMut(DiagArgName, DiagArgValue)) + 'a> =
            Box::new(f);
        Self { msg, add_args }
    }
}

impl fmt::Debug for CustomSubdiagnostic<'_> {
    /// Prints only the type name; neither field is shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("CustomSubdiagnostic");
        repr.finish_non_exhaustive()
    }
}
/// Diagnostic enum in which every variant selects its own `middle_layout_*` message.
///
/// Most variants carry the type concerned as `ty`; `Cycle` and `ReferencesError` carry no data.
/// No variant has a span field.
#[derive(Diagnostic)]
pub enum LayoutError<'tcx> {
#[diag(middle_layout_unknown)]
Unknown { ty: Ty<'tcx> },
#[diag(middle_layout_too_generic)]
TooGeneric { ty: Ty<'tcx> },
#[diag(middle_layout_size_overflow)]
Overflow { ty: Ty<'tcx> },
// Also carries the lane limit as `max_lanes`.
#[diag(middle_layout_simd_too_many)]
SimdTooManyLanes { ty: Ty<'tcx>, max_lanes: u64 },
#[diag(middle_layout_simd_zero_length)]
SimdZeroLength { ty: Ty<'tcx> },
// `failure_ty` is a pre-rendered `String`, unlike `ty`.
#[diag(middle_layout_normalization_failure)]
NormalizationFailure { ty: Ty<'tcx>, failure_ty: String },
#[diag(middle_layout_cycle)]
Cycle,
#[diag(middle_layout_references_error)]
ReferencesError,
}
/// Diagnostic for the `middle_erroneous_constant` message; carries only a primary span.
#[derive(Diagnostic)]
#[diag(middle_erroneous_constant)]
pub(crate) struct ErroneousConstant {
/// Primary location the diagnostic points at.
#[primary_span]
pub span: Span,
}
/// Diagnostic for the `middle_type_length_limit` message.
///
/// A help using the `middle_consider_type_length_limit` message is attached as well.
#[derive(Diagnostic)]
#[diag(middle_type_length_limit)]
#[help(middle_consider_type_length_limit)]
pub(crate) struct TypeLengthLimit<'tcx> {
/// Primary location the diagnostic points at.
#[primary_span]
pub span: Span,
// NOTE(review): `instance` and `type_length` carry no attribute; presumably they are
// interpolated into the message/help text -- confirm against the message definitions.
pub instance: Instance<'tcx>,
pub type_length: usize,
}
/// Diagnostic for the `middle_max_num_nodes_in_valtree` message.
#[derive(Diagnostic)]
#[diag(middle_max_num_nodes_in_valtree)]
pub(crate) struct MaxNumNodesInValtree {
/// Primary location the diagnostic points at.
#[primary_span]
pub span: Span,
// NOTE(review): pre-rendered identifier of the constant concerned; presumably interpolated
// into the message -- confirm against the message definition.
pub global_const_id: String,
}
/// Diagnostic for the `middle_invalid_const_in_valtree` message.
///
/// A `#[note]` is attached to the diagnostic in addition to the main message.
#[derive(Diagnostic)]
#[diag(middle_invalid_const_in_valtree)]
#[note]
pub(crate) struct InvalidConstInValtree {
/// Primary location the diagnostic points at.
#[primary_span]
pub span: Span,
// NOTE(review): pre-rendered identifier of the constant concerned; presumably interpolated
// into the message -- confirm against the message definition.
pub global_const_id: String,
}