Skip to content

TypeTree support in autodiff #144197

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions compiler/rustc_ast/src/expand/autodiff_attrs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
use std::fmt::{self, Display, Formatter};
use std::str::FromStr;

use crate::expand::typetree::TypeTree;
use crate::expand::{Decodable, Encodable, HashStable_Generic};
use crate::{Ty, TyKind};

Expand Down Expand Up @@ -84,6 +85,8 @@ pub struct AutoDiffItem {
/// The name of the function being generated
pub target: String,
pub attrs: AutoDiffAttrs,
pub inputs: Vec<TypeTree>,
pub output: TypeTree,
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
Expand Down Expand Up @@ -275,14 +278,22 @@ impl AutoDiffAttrs {
!matches!(self.mode, DiffMode::Error | DiffMode::Source)
}

pub fn into_item(self, source: String, target: String) -> AutoDiffItem {
AutoDiffItem { source, target, attrs: self }
/// Consumes these attributes and packages them into an [`AutoDiffItem`].
///
/// `source` and `target` are the two function names the item reports as
/// "Differentiating {source} -> {target}"; `inputs` and `output` are the
/// type trees stored alongside them.
pub fn into_item(
    self,
    source: String,
    target: String,
    inputs: Vec<TypeTree>,
    output: TypeTree,
) -> AutoDiffItem {
    let attrs = self;
    AutoDiffItem { source, target, attrs, inputs, output }
}
}

impl fmt::Display for AutoDiffItem {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
write!(f, " with attributes: {:?}", self.attrs)
write!(f, " with attributes: {:?}", self.attrs)?;
write!(f, " with inputs: {:?}", self.inputs)?;
write!(f, " with output: {:?}", self.output)
}
}
1 change: 1 addition & 0 deletions compiler/rustc_ast/src/expand/typetree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub enum Kind {
Half,
Float,
Double,
F128,
Unknown,
}

Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_codegen_gcc/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1383,6 +1383,7 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> {
_src_align: Align,
size: RValue<'gcc>,
flags: MemFlags,
_tt: Option<rustc_ast::expand::typetree::FncTree>, // Autodiff TypeTrees are LLVM-only, ignored in GCC backend
) {
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
let size = self.intcast(size, self.type_size_t(), false);
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_codegen_gcc/src/intrinsic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,7 @@ impl<'gcc, 'tcx> ArgAbiExt<'gcc, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
scratch_align,
bx.const_usize(self.layout.size.bytes()),
MemFlags::empty(),
None,
);

bx.lifetime_end(scratch, scratch_size);
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_codegen_llvm/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
scratch_align,
bx.const_usize(copy_bytes),
MemFlags::empty(),
None,
);
bx.lifetime_end(llscratch, scratch_size);
}
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_codegen_llvm/src/back/lto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,8 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
config::AutoDiff::Enable => {}
// We handle this below
config::AutoDiff::NoPostopt => {}
// Disables TypeTree generation
config::AutoDiff::NoTT => {}
}
}
// This helps with handling enums for now.
Expand Down
15 changes: 13 additions & 2 deletions compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::borrow::{Borrow, Cow};
use std::ops::Deref;
use std::{iter, ptr};

use rustc_ast::expand::typetree::FncTree;
pub(crate) mod autodiff;
pub(crate) mod gpu_offload;

Expand Down Expand Up @@ -1118,11 +1119,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
src_align: Align,
size: &'ll Value,
flags: MemFlags,
tt: Option<FncTree>,
) {
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
let size = self.intcast(size, self.type_isize(), false);
let is_volatile = flags.contains(MemFlags::VOLATILE);
unsafe {
let memcpy = unsafe {
llvm::LLVMRustBuildMemCpy(
self.llbuilder,
dst,
Expand All @@ -1131,7 +1133,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
src_align.bytes() as c_uint,
size,
is_volatile,
);
)
};

// TypeTree metadata for memcpy is especially important: when Enzyme encounters
// a memcpy during autodiff, it needs to know the structure of the data being
// copied to properly track derivatives. For example, copying an array of floats
// vs. copying a struct with mixed types requires different derivative handling.
// The TypeTree tells Enzyme exactly what memory layout to expect.
if let Some(tt) = tt {
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt);
}
}

Expand Down
6 changes: 6 additions & 0 deletions compiler/rustc_codegen_llvm/src/builder/autodiff.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::ptr;

use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
use rustc_ast::expand::typetree::FncTree;
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_middle::ty::{PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
Expand Down Expand Up @@ -254,6 +255,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
fn_args: &[&'ll Value],
attrs: AutoDiffAttrs,
dest: PlaceRef<'tcx, &'ll Value>,
fnc_tree: FncTree,
) {
// We have to pick the name depending on whether we want forward or reverse mode autodiff.
let mut ad_name: String = match attrs.mode {
Expand Down Expand Up @@ -330,6 +332,10 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
fn_args,
);

if !fnc_tree.args.is_empty() || !fnc_tree.ret.0.is_empty() {
crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, fnc_tree);
}

let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);

builder.store_to_place(call, dest.val);
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1202,6 +1202,9 @@ fn codegen_autodiff<'ll, 'tcx>(
&mut diff_attrs.input_activity,
);

let fnc_tree =
rustc_middle::ty::fnc_typetrees(tcx, fn_source.ty(tcx, TypingEnv::fully_monomorphized()));

// Build body
generate_enzyme_call(
bx,
Expand All @@ -1212,6 +1215,7 @@ fn codegen_autodiff<'ll, 'tcx>(
&val_arr,
diff_attrs.clone(),
result,
fnc_tree,
);
}

Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_codegen_llvm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ mod llvm_util;
mod mono_item;
mod type_;
mod type_of;
mod typetree;
mod va_arg;
mod value;

Expand Down
183 changes: 182 additions & 1 deletion compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,36 @@
use libc::{c_char, c_uint};

use super::MetadataKindId;
use super::ffi::{AttributeKind, BasicBlock, Metadata, Module, Type, Value};
use super::ffi::{AttributeKind, BasicBlock, Context, Metadata, Module, Type, Value};
use crate::llvm::{Bool, Builder};

// TypeTree types
/// Raw handle to a type tree owned by the Enzyme library.
///
/// In this file such handles are produced by the `EnzymeNewTypeTree*`
/// functions and released with `EnzymeFreeTypeTree`.
pub(crate) type CTypeTreeRef = *mut EnzymeTypeTree;

/// Zero-sized `#[repr(C)]` stand-in for Enzyme's type-tree object.
///
/// The visible code only ever handles it through [`CTypeTreeRef`]; the
/// `_unused` field carries no data.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub(crate) struct EnzymeTypeTree {
_unused: [u8; 0],
}

#[repr(u32)]
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
#[allow(non_camel_case_types)]
pub(crate) enum CConcreteType {
DT_Anything = 0,
DT_Integer = 1,
DT_Pointer = 2,
DT_Half = 3,
DT_Float = 4,
DT_Double = 5,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd probably add f128 here on the Enzyme side.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// FIXME(KMJ-007): handle f128 using long double here (https://github.com/EnzymeAD/Enzyme/issues/1600)
DT_Unknown = 6,
}

/// Rust-side wrapper around a raw Enzyme type-tree handle.
///
/// The `Drop` impl in this file passes `inner` to `EnzymeFreeTypeTree`, and
/// the `Clone` impl obtains a fresh handle through `EnzymeNewTypeTreeTR`.
pub(crate) struct TypeTree {
// Raw handle handed to every Enzyme call made on behalf of this value.
// NOTE(review): because `Drop` frees it, code that copies this pointer out
// presumably must not free it a second time — confirm against callers.
pub(crate) inner: CTypeTreeRef,
}

#[link(name = "llvm-wrapper", kind = "static")]
unsafe extern "C" {
// Enzyme
Expand Down Expand Up @@ -68,10 +95,33 @@ pub(crate) mod Enzyme_AD {

use libc::c_void;

use super::{CConcreteType, CTypeTreeRef, Context};

// Foreign setter functions; their definitions are not part of this file.
// NOTE(review): judging only by the names and signatures, `arg1` is presumably
// the address of an Enzyme option and `arg2` the value to store — confirm
// against Enzyme's C API before relying on this.
unsafe extern "C" {
pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char);
}

// TypeTree functions
// Foreign declarations for Enzyme's type-tree API. The notes below describe
// only how the safe `TypeTree` wrapper in this file uses each function; the
// authoritative semantics live on the Enzyme side.
unsafe extern "C" {
/// Returns a new type-tree handle; used by `TypeTree::new`.
pub(crate) fn EnzymeNewTypeTree() -> CTypeTreeRef;
/// Returns a new type-tree handle built from a concrete type and an LLVM
/// context; used by `TypeTree::from_type`.
pub(crate) fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef;
/// Returns a new type-tree handle derived from an existing one; used by
/// `TypeTree`'s `Clone` impl.
pub(crate) fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef;
/// Releases a type-tree handle; called from `TypeTree`'s `Drop` impl.
pub(crate) fn EnzymeFreeTypeTree(CTT: CTypeTreeRef);
/// Called by `TypeTree::merge` with the receiver as `arg1` and the other
/// tree as `arg2`; the wrapper ignores the returned `bool`.
/// NOTE(review): the meaning of the return value is not visible here.
pub(crate) fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool;
/// Not called from the code visible in this file.
pub(crate) fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64);
/// Not called from the code visible in this file.
pub(crate) fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef);
/// Called by `TypeTree::shift`, which passes `data_layout` as a
/// NUL-terminated string obtained from a `CString`.
pub(crate) fn EnzymeTypeTreeShiftIndiciesEq(
arg1: CTypeTreeRef,
data_layout: *const c_char,
offset: i64,
max_size: i64,
add_offset: u64,
);
/// Returns a C string describing the tree; the `Display` impl reads it with
/// `CStr::from_ptr` and then hands it to `EnzymeTypeTreeToStringFree`.
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
/// Releases a string previously returned by `EnzymeTypeTreeToString`.
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
}

unsafe extern "C" {
static mut EnzymePrintPerf: c_void;
static mut EnzymePrintActivity: c_void;
Expand Down Expand Up @@ -141,6 +191,57 @@ pub(crate) use self::Fallback_AD::*;
pub(crate) mod Fallback_AD {
#![allow(unused_variables)]

use libc::c_char;

use super::{CConcreteType, CTypeTreeRef, Context};

// TypeTree function fallbacks
//
// Each stub below has the same name, parameters and return type as the
// corresponding foreign declaration in `Enzyme_AD`, but its body is
// `unimplemented!()`, so calling any of them panics.
// NOTE(review): presumably these are the versions compiled when Enzyme is not
// linked in; the cfg gate selecting this module is outside this view — confirm.

// Stub for creating a new type tree; panics when called.
pub(crate) unsafe fn EnzymeNewTypeTree() -> CTypeTreeRef {
unimplemented!()
}

// Stub for creating a type tree from a concrete type; panics when called.
pub(crate) unsafe fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef {
unimplemented!()
}

// Stub for creating a type tree from an existing one; panics when called.
pub(crate) unsafe fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef {
unimplemented!()
}

// Stub for releasing a type tree; panics when called.
pub(crate) unsafe fn EnzymeFreeTypeTree(CTT: CTypeTreeRef) {
unimplemented!()
}

// Stub for merging two type trees; panics when called.
pub(crate) unsafe fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool {
unimplemented!()
}

// Stub; panics when called.
pub(crate) unsafe fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64) {
unimplemented!()
}

// Stub; panics when called.
pub(crate) unsafe fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef) {
unimplemented!()
}

// Stub for the index-shifting call used by `TypeTree::shift`; panics when called.
pub(crate) unsafe fn EnzymeTypeTreeShiftIndiciesEq(
arg1: CTypeTreeRef,
data_layout: *const c_char,
offset: i64,
max_size: i64,
add_offset: u64,
) {
unimplemented!()
}

// Stub for rendering a type tree as a C string; panics when called.
pub(crate) unsafe fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char {
unimplemented!()
}

// Stub for releasing a string returned by `EnzymeTypeTreeToString`; panics when called.
pub(crate) unsafe fn EnzymeTypeTreeToStringFree(arg1: *const c_char) {
unimplemented!()
}

pub(crate) fn set_inline(val: bool) {
unimplemented!()
}
Expand Down Expand Up @@ -169,3 +270,83 @@ pub(crate) mod Fallback_AD {
unimplemented!()
}
}

impl TypeTree {
    /// Wraps a handle freshly obtained from `EnzymeNewTypeTree`.
    pub(crate) fn new() -> TypeTree {
        TypeTree { inner: unsafe { EnzymeNewTypeTree() } }
    }

    /// Wraps a handle obtained from `EnzymeNewTypeTreeCT` for the concrete
    /// type `t` in the LLVM context `ctx`.
    pub(crate) fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree {
        TypeTree { inner: unsafe { EnzymeNewTypeTreeCT(t, ctx) } }
    }

    /// Passes both trees to `EnzymeMergeTypeTree` (receiver first), then
    /// drops `other` and hands the receiver back.
    pub(crate) fn merge(self, other: Self) -> Self {
        // The boolean reported by Enzyme is intentionally not inspected.
        let _ = unsafe { EnzymeMergeTypeTree(self.inner, other.inner) };
        // `other` stays alive until the merge call has returned.
        drop(other);
        self
    }

    /// Forwards to `EnzymeTypeTreeShiftIndiciesEq` on this tree's handle and
    /// returns the tree.
    ///
    /// # Panics
    ///
    /// Panics if `layout` contains an interior NUL byte, because it cannot
    /// then be converted into a C string.
    #[must_use]
    pub(crate) fn shift(
        self,
        layout: &str,
        offset: isize,
        max_size: isize,
        add_offset: usize,
    ) -> Self {
        let c_layout = std::ffi::CString::new(layout).unwrap();
        // Convert to the fixed-width integer types of the C signature.
        let (c_offset, c_max_size, c_add_offset) =
            (offset as i64, max_size as i64, add_offset as u64);

        // `c_layout` outlives the call, so the pointer stays valid throughout.
        unsafe {
            EnzymeTypeTreeShiftIndiciesEq(
                self.inner,
                c_layout.as_ptr(),
                c_offset,
                c_max_size,
                c_add_offset,
            );
        }

        self
    }
}

impl Clone for TypeTree {
    /// Asks Enzyme for a new handle derived from this one via
    /// `EnzymeNewTypeTreeTR` and wraps it in its own `TypeTree`.
    fn clone(&self) -> Self {
        let source = self.inner;
        Self { inner: unsafe { EnzymeNewTypeTreeTR(source) } }
    }
}

impl std::fmt::Display for TypeTree {
    /// Writes Enzyme's textual rendering of this type tree to `f`.
    ///
    /// If the string returned by Enzyme is not valid UTF-8, the text
    /// `could not parse: <error>` is written instead.
    ///
    /// # Errors
    ///
    /// Returns whatever error the underlying formatter reports.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ptr = unsafe { EnzymeTypeTreeToString(self.inner) };
        let cstr = unsafe { std::ffi::CStr::from_ptr(ptr) };

        // Keep the outcome instead of returning early with `?`: the C string
        // must be released even when the formatter reports an error, otherwise
        // it would be leaked on that path.
        let result = match cstr.to_str() {
            Ok(x) => write!(f, "{}", x),
            Err(err) => write!(f, "could not parse: {}", err),
        };

        // Release the C string; `cstr` is not used past this point.
        unsafe {
            EnzymeTypeTreeToStringFree(ptr);
        }

        result
    }
}

impl std::fmt::Debug for TypeTree {
    /// Debug output is identical to the `Display` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(self, f)
    }
}

impl Drop for TypeTree {
    /// Hands the wrapped handle back to Enzyme via `EnzymeFreeTypeTree`.
    fn drop(&mut self) {
        let handle = self.inner;
        unsafe {
            EnzymeFreeTypeTree(handle);
        }
    }
}
Loading
Loading