EnzymeAD / rust

A rust fork to work towards Enzyme integration
https://www.rust-lang.org

Can't get `Reverse` to work #151

Open marcpabst opened 1 month ago

marcpabst commented 1 month ago

This example, taken directly from the docs, either always returns 0 for the derivative (debug) or does not even compile (release):

#![feature(autodiff)]

#[autodiff(df, Reverse, Active, Active, Active)]
 fn f(x: f32, y: f32) -> f32 {
     x * x + 3.0 * y
 }

 fn main() {
     let (x, y) = (5.0, 7.0);
     let (z, bx, by) = df(x, y, 1.0);
     assert_eq!(46.0, z);
     assert_eq!(10.0, bx);
     assert_eq!(3.0, by);
 }

Error in debug mode:

thread 'main' panicked at src/main.rs:12:6:
assertion `left == right` failed
  left: 10.0
 right: 0.0
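
For reference, the values the assertions expect follow from f(x, y) = x² + 3y: the partial derivatives are ∂f/∂x = 2x and ∂f/∂y = 3. The sketch below writes them out by hand in plain Rust; `df_by_hand` is a hypothetical helper for comparison only, not what Enzyme generates, and its return order `(z, bx, by)` simply mirrors the call in the example above.

```rust
/// Primal function from the report: f(x, y) = x^2 + 3y.
fn f(x: f32, y: f32) -> f32 {
    x * x + 3.0 * y
}

/// Hand-derived reverse-mode result (hypothetical helper, not Enzyme output).
/// Returns (f(x, y), seed * df/dx, seed * df/dy), where `seed` plays the
/// role of the `1.0` passed as the last argument to `df` in the report.
fn df_by_hand(x: f32, y: f32, seed: f32) -> (f32, f32, f32) {
    let z = f(x, y);
    let bx = seed * 2.0 * x; // d/dx (x^2 + 3y) = 2x
    let by = seed * 3.0; // d/dy (x^2 + 3y) = 3
    (z, bx, by)
}
```

At (x, y) = (5, 7) with a seed of 1 this gives z = 46, bx = 10 and by = 3, so the `right: 0.0` in the debug panic means the generated `df` produced a zero gradient for `x`.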

Stack trace in release mode:

thread 'coordinator' panicked at compiler/rustc_codegen_llvm/src/back/write.rs:761:17:
assertion failed: llvm::LLVMRustGetTypeKind(outer_arg_ty) == llvm::TypeKind::Integer
stack backtrace:
   0:        0x101fcc848 - std::backtrace::Backtrace::create::h274c998bd3e77e61
   1:        0x1066b2624 - <alloc[9eedf9a5b017d5f0]::boxed::Box<rustc_driver_impl[4092d44f8fc6cf45]::install_ice_hook::{closure#0}> as core[4bc86d92358c3fc9]::ops::function::Fn<(&dyn for<'a, 'b> core[4bc86d92358c3fc9]::ops::function::Fn<(&'a core[4bc86d92358c3fc9]::panic::panic_info::PanicInfo<'b>,), Output = ()> + core[4bc86d92358c3fc9]::marker::Send + core[4bc86d92358c3fc9]::marker::Sync, &core[4bc86d92358c3fc9]::panic::panic_info::PanicInfo)>>::call
   2:        0x101ff7a64 - std::panicking::rust_panic_with_hook::h7b497338c9ec8662
   3:        0x101fd1628 - std::panicking::begin_panic_handler::{{closure}}::h880ad4dd90e1a8f8
   4:        0x101fd12a4 - std::sys_common::backtrace::__rust_end_short_backtrace::h10f05ee841af1dbc
   5:        0x101ff76c4 - _rust_begin_unwind
   6:        0x10204f588 - core::panicking::panic_fmt::heb8a25d7321e04b5
   7:        0x10204f610 - core::panicking::panic::h1b598ba5ca68e7a2
   8:        0x10699d8f0 - rustc_codegen_llvm[aebffebb323e7405]::back::write::enzyme_ad
   9:        0x10699ed9c - rustc_codegen_llvm[aebffebb323e7405]::back::write::differentiate
  10:        0x1069cded4 - <rustc_codegen_llvm[aebffebb323e7405]::LlvmCodegenBackend as rustc_codegen_ssa[1b61eafb5d4c07bc]::traits::write::WriteBackendMethods>::autodiff
  11:        0x106987214 - <rustc_codegen_ssa[1b61eafb5d4c07bc]::back::lto::LtoModuleCodegen<rustc_codegen_llvm[aebffebb323e7405]::LlvmCodegenBackend>>::autodiff
  12:        0x1068e45f4 - std[94fe895d40b1402d]::sys_common::backtrace::__rust_begin_short_backtrace::<<rustc_codegen_llvm[aebffebb323e7405]::LlvmCodegenBackend as rustc_codegen_ssa[1b61eafb5d4c07bc]::traits::backend::ExtraBackendMethods>::spawn_named_thread<rustc_codegen_ssa[1b61eafb5d4c07bc]::back::write::start_executing_work<rustc_codegen_llvm[aebffebb323e7405]::LlvmCodegenBackend>::{closure#5}, core[4bc86d92358c3fc9]::result::Result<rustc_codegen_ssa[1b61eafb5d4c07bc]::back::write::CompiledModules, ()>>::{closure#0}, core[4bc86d92358c3fc9]::result::Result<rustc_codegen_ssa[1b61eafb5d4c07bc]::back::write::CompiledModules, ()>>
  13:        0x1068e99b8 - <<std[94fe895d40b1402d]::thread::Builder>::spawn_unchecked_<<rustc_codegen_llvm[aebffebb323e7405]::LlvmCodegenBackend as rustc_codegen_ssa[1b61eafb5d4c07bc]::traits::backend::ExtraBackendMethods>::spawn_named_thread<rustc_codegen_ssa[1b61eafb5d4c07bc]::back::write::start_executing_work<rustc_codegen_llvm[aebffebb323e7405]::LlvmCodegenBackend>::{closure#5}, core[4bc86d92358c3fc9]::result::Result<rustc_codegen_ssa[1b61eafb5d4c07bc]::back::write::CompiledModules, ()>>::{closure#0}, core[4bc86d92358c3fc9]::result::Result<rustc_codegen_ssa[1b61eafb5d4c07bc]::back::write::CompiledModules, ()>>::{closure#1} as core[4bc86d92358c3fc9]::ops::function::FnOnce<()>>::call_once::{shim:vtable#0}
  14:        0x101ffe8c4 - std::sys::unix::thread::Thread::new::thread_start::h530c9e8206466847
  15:        0x1805a6034 - __pthread_joiner_wake

rustc version: 1.75.0-nightly (bafd64d28 2024-08-01)
platform: aarch64-apple-darwin

thread 'rustc' panicked at compiler/rustc_middle/src/util/bug.rs:36:26:
/Users/marc/rust/compiler/rustc_codegen_ssa/src/back/write.rs:1984:17: panic during codegen/LLVM phase
stack backtrace:
   0:        0x101fcc848 - std::backtrace::Backtrace::create::h274c998bd3e77e61
   1:        0x1066b2624 - <alloc[9eedf9a5b017d5f0]::boxed::Box<rustc_driver_impl[4092d44f8fc6cf45]::install_ice_hook::{closure#0}> as core[4bc86d92358c3fc9]::ops::function::Fn<(&dyn for<'a, 'b> core[4bc86d92358c3fc9]::ops::function::Fn<(&'a core[4bc86d92358c3fc9]::panic::panic_info::PanicInfo<'b>,), Output = ()> + core[4bc86d92358c3fc9]::marker::Send + core[4bc86d92358c3fc9]::marker::Sync, &core[4bc86d92358c3fc9]::panic::panic_info::PanicInfo)>>::call
   2:        0x101ff7a64 - std::panicking::rust_panic_with_hook::h7b497338c9ec8662
   3:        0x1086acc90 - std[94fe895d40b1402d]::panicking::begin_panic::<alloc[9eedf9a5b017d5f0]::string::String>::{closure#0}
   4:        0x1086ac6e8 - std[94fe895d40b1402d]::sys_common::backtrace::__rust_end_short_backtrace::<std[94fe895d40b1402d]::panicking::begin_panic<alloc[9eedf9a5b017d5f0]::string::String>::{closure#0}, !>
   5:        0x108c177f4 - std[94fe895d40b1402d]::panicking::begin_panic::<alloc[9eedf9a5b017d5f0]::string::String>
   6:        0x108640758 - rustc_middle[96f22e03556ee0e5]::util::bug::opt_span_bug_fmt::<rustc_span[3fa53b66ea663371]::span_encoding::Span>::{closure#0}
   7:        0x10863d7b4 - rustc_middle[96f22e03556ee0e5]::ty::context::tls::with_opt::<rustc_middle[96f22e03556ee0e5]::util::bug::opt_span_bug_fmt<rustc_span[3fa53b66ea663371]::span_encoding::Span>::{closure#0}, !>::{closure#0}
   8:        0x10863d780 - rustc_middle[96f22e03556ee0e5]::ty::context::tls::with_context_opt::<rustc_middle[96f22e03556ee0e5]::ty::context::tls::with_opt<rustc_middle[96f22e03556ee0e5]::util::bug::opt_span_bug_fmt<rustc_span[3fa53b66ea663371]::span_encoding::Span>::{closure#0}, !>::{closure#0}, !>
   9:        0x108c187e8 - rustc_middle[96f22e03556ee0e5]::util::bug::bug_fmt
  10:        0x1069b6b94 - <rustc_session[25050e0163d9f6fb]::session::Session>::time::<rustc_codegen_ssa[1b61eafb5d4c07bc]::back::write::CompiledModules, <rustc_codegen_ssa[1b61eafb5d4c07bc]::back::write::OngoingCodegen<rustc_codegen_llvm[aebffebb323e7405]::LlvmCodegenBackend>>::join::{closure#0}>
  11:        0x1068eadfc - <rustc_codegen_ssa[1b61eafb5d4c07bc]::back::write::OngoingCodegen<rustc_codegen_llvm[aebffebb323e7405]::LlvmCodegenBackend>>::join
  12:        0x1069ce8a4 - <rustc_codegen_llvm[aebffebb323e7405]::LlvmCodegenBackend as rustc_codegen_ssa[1b61eafb5d4c07bc]::traits::backend::CodegenBackend>::join_codegen
  13:        0x10685e458 - <rustc_interface[90b1e0870df88c87]::queries::Linker>::link
  14:        0x1066a47fc - rustc_span[3fa53b66ea663371]::set_source_map::<core[4bc86d92358c3fc9]::result::Result<(), rustc_span[3fa53b66ea663371]::ErrorGuaranteed>, rustc_interface[90b1e0870df88c87]::interface::run_compiler<core[4bc86d92358c3fc9]::result::Result<(), rustc_span[3fa53b66ea663371]::ErrorGuaranteed>, rustc_driver_impl[4092d44f8fc6cf45]::run_compiler::{closure#0}>::{closure#0}::{closure#0}>
  15:        0x1066e1070 - <scoped_tls[df23976295fd3450]::ScopedKey<rustc_span[3fa53b66ea663371]::SessionGlobals>>::set::<rustc_interface[90b1e0870df88c87]::util::run_in_thread_pool_with_globals<rustc_interface[90b1e0870df88c87]::interface::run_compiler<core[4bc86d92358c3fc9]::result::Result<(), rustc_span[3fa53b66ea663371]::ErrorGuaranteed>, rustc_driver_impl[4092d44f8fc6cf45]::run_compiler::{closure#0}>::{closure#0}, core[4bc86d92358c3fc9]::result::Result<(), rustc_span[3fa53b66ea663371]::ErrorGuaranteed>>::{closure#0}, core[4bc86d92358c3fc9]::result::Result<(), rustc_span[3fa53b66ea663371]::ErrorGuaranteed>>
  16:        0x1066a4bb4 - rustc_span[3fa53b66ea663371]::create_session_globals_then::<core[4bc86d92358c3fc9]::result::Result<(), rustc_span[3fa53b66ea663371]::ErrorGuaranteed>, rustc_interface[90b1e0870df88c87]::util::run_in_thread_pool_with_globals<rustc_interface[90b1e0870df88c87]::interface::run_compiler<core[4bc86d92358c3fc9]::result::Result<(), rustc_span[3fa53b66ea663371]::ErrorGuaranteed>, rustc_driver_impl[4092d44f8fc6cf45]::run_compiler::{closure#0}>::{closure#0}, core[4bc86d92358c3fc9]::result::Result<(), rustc_span[3fa53b66ea663371]::ErrorGuaranteed>>::{closure#0}>
  17:        0x1066d94a0 - std[94fe895d40b1402d]::sys_common::backtrace::__rust_begin_short_backtrace::<rustc_interface[90b1e0870df88c87]::util::run_in_thread_with_globals<rustc_interface[90b1e0870df88c87]::util::run_in_thread_pool_with_globals<rustc_interface[90b1e0870df88c87]::interface::run_compiler<core[4bc86d92358c3fc9]::result::Result<(), rustc_span[3fa53b66ea663371]::ErrorGuaranteed>, rustc_driver_impl[4092d44f8fc6cf45]::run_compiler::{closure#0}>::{closure#0}, core[4bc86d92358c3fc9]::result::Result<(), rustc_span[3fa53b66ea663371]::ErrorGuaranteed>>::{closure#0}, core[4bc86d92358c3fc9]::result::Result<(), rustc_span[3fa53b66ea663371]::ErrorGuaranteed>>::{closure#0}::{closure#0}, core[4bc86d92358c3fc9]::result::Result<(), rustc_span[3fa53b66ea663371]::ErrorGuaranteed>>
  18:        0x1066e8a3c - <<std[94fe895d40b1402d]::thread::Builder>::spawn_unchecked_<rustc_interface[90b1e0870df88c87]::util::run_in_thread_with_globals<rustc_interface[90b1e0870df88c87]::util::run_in_thread_pool_with_globals<rustc_interface[90b1e0870df88c87]::interface::run_compiler<core[4bc86d92358c3fc9]::result::Result<(), rustc_span[3fa53b66ea663371]::ErrorGuaranteed>, rustc_driver_impl[4092d44f8fc6cf45]::run_compiler::{closure#0}>::{closure#0}, core[4bc86d92358c3fc9]::result::Result<(), rustc_span[3fa53b66ea663371]::ErrorGuaranteed>>::{closure#0}, core[4bc86d92358c3fc9]::result::Result<(), rustc_span[3fa53b66ea663371]::ErrorGuaranteed>>::{closure#0}::{closure#0}, core[4bc86d92358c3fc9]::result::Result<(), rustc_span[3fa53b66ea663371]::ErrorGuaranteed>>::{closure#1} as core[4bc86d92358c3fc9]::ops::function::FnOnce<()>>::call_once::{shim:vtable#0}
  19:        0x101ffe8c4 - std::sys::unix::thread::Thread::new::thread_start::h530c9e8206466847
  20:        0x1805a6034 - __pthread_joiner_wake

jedbrown commented 1 month ago

This one is a known bug (see cfg(broken)): https://github.com/EnzymeAD/rustbook/blob/main/samples/tests/reverse/mod.rs#L139-L156

But you make a very good point that we shouldn't have it as an example in the book until it is fixed. @ZuseZ4 Do you know how hard this one is to fix?

ZuseZ4 commented 1 month ago

IIRC it's pretty hard, with only a very limited benefit. We would literally have to hardcode this case (two f32 values getting passed as one f64 or i64, I don't remember which), so any case passing e.g. f32, i32, f32 would still not be handled. And in the time it would take to fix it (if I even manage to; there might be Enzyme core bugs that need to be fixed for it too), I could help with getting more valuable features in, like macOS support, or work on the upstreaming. Here are more cases that fail, and each of them would need its own fix, since each is a special case where rustc applies some optimization that changes the type: https://github.com/EnzymeAD/rust/issues/105 So I think it's better to just remove the Reverse mode docs and add some examples that work.
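
To make "two f32 getting passed as one i64" concrete, here is a rough illustration in plain Rust of what such a merged argument looks like from the callee's side. It is only a sketch: `pack`, `unpack` and `f_packed` are hypothetical names, and which half of the integer holds which float is an assumption, not something taken from rustc.

```rust
/// Hypothetical illustration: store the bit patterns of two f32 values in
/// one 64-bit integer (x in the low half, y in the high half; this layout
/// is an assumption made for the example).
fn pack(x: f32, y: f32) -> u64 {
    (x.to_bits() as u64) | ((y.to_bits() as u64) << 32)
}

/// Recover the two f32 values from the packed integer.
fn unpack(p: u64) -> (f32, f32) {
    let x = f32::from_bits(p as u32); // low 32 bits
    let y = f32::from_bits((p >> 32) as u32); // high 32 bits
    (x, y)
}

/// The function from the report, seen through the merged signature:
/// a single integer argument instead of two floats.
fn f_packed(p: u64) -> f32 {
    let (x, y) = unpack(p);
    x * x + 3.0 * y
}
```

The computed value is unchanged, but the signature no longer shows two separate float inputs, which is why a tool working on the lowered function needs case-by-case knowledge of how the original arguments were merged.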

marcpabst commented 1 month ago

Just so that I understand correctly: this is an issue with a specific combination of inputs, due to how rustc optimises certain things?

wsmoses commented 1 month ago

Is there a calling-convention flag somewhere inside Rust that can be used to prevent args from getting coalesced? Alternatively, you could forcibly pass all args by reference and then shim.
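
A minimal sketch of the pass-by-reference-and-shim idea, in plain Rust and without the autodiff attribute (its annotations for the by-reference function are not given in this thread, so they are left out). The names `f_by_ref` and `f` are illustrative only, and whether this actually avoids the failing case in the fork is untested here.

```rust
/// Inner function that takes every input by reference. The intent is that
/// each argument is lowered as its own pointer, so two scalars are not
/// merged into one value at the call boundary. This is the function the
/// autodiff attribute would be attached to (omitted in this sketch).
#[inline(never)]
fn f_by_ref(x: &f32, y: &f32) -> f32 {
    *x * *x + 3.0 * *y
}

/// By-value shim that keeps the original signature for callers and
/// forwards to the by-reference version.
fn f(x: f32, y: f32) -> f32 {
    f_by_ref(&x, &y)
}
```

Callers keep using `f(5.0, 7.0)` as before; only the differentiated entry point changes shape, and the shim adds an indirection that has to be maintained alongside it.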

ZuseZ4 commented 1 month ago

Yep, but it's one of the cases where a hack will take time and has its own downsides, so I currently still lean towards a more proper fix, where I check on a more granular level what rustc does and only disable those optimizations that are too magic for Enzyme to reasonably handle. I also have it in mind as a probably easy onboarding task, where anyone can try to add Rust handling for one more special case. Since rustc only has a limited number of these optimizations and doesn't change them much, it wouldn't become a never-ending story.