EnzymeAD / rust

A rust fork to work towards Enzyme integration

autodiff: dot product patterns broken in debug #136

Open jedbrown opened 1 month ago

jedbrown commented 1 month ago

Some common ways to implement a dot product fail to work in debug profile (but work in release).

#![feature(autodiff)]

const N: usize = 3;

#[autodiff(b_dot_local, Reverse, Duplicated, Duplicated, Active)]
fn dot_local(x: &[f64], y: &[f64]) -> f64 {
    x.iter().zip(y).map(|(x, y)| x * y).sum()

    // let mut sum = 0.0;
    // for i in 0..N {
    //     sum += x[i] * y[i];
    // }
    // sum

    //// Works always
    // x[0] * y[0] + x[1] * y[1] + x[2] * y[2]
}

fn new_array<const N: usize>(start: i32) -> [f64; N] {
    let mut x = [0.0; N];
    for i in 0..N {
        x[i] = (start as i64 + i as i64) as f64;
    }
    x
}

fn main() {
    let rank = 1;
    let x = new_array::<N>(10 * rank);
    let y = new_array::<N>(100 * rank);
    let r = dot_local(&x, &y);
    println!("[{}] local: {}", rank, r);

    let mut bx = [0.0; N];
    let mut by = [0.0; N];
    if true {
        let r = b_dot_local(&x, &mut bx, &y, &mut by, 1.0);
        println!("[{}] r: {}, bx: {:?}, by: {:?}", rank, r, bx, by);
    }
}

This yields a huge output ending with

$ cargo +enzyme r --bin=dot
[...]
Illegal updateAnalysis prev:{[-1]:Pointer, [-1,-1]:Float@double} new: {[-1]:Integer}
val:   %18 = ptrtoint ptr %17 to i64, !dbg !991 origin=  store i64 %18, ptr %9, align 8, !dbg !992

If I switch to the version that increments sum, I get

$ cargo +enzyme r --bin=dot
[...]
source_id: DefId(0:11 ~ dot[742f]::dot_local)
num_fnc_args: 4
input_activity.len(): 4
error: /home/jed/src/rust-enzyme/library/core/src/iter/range.rs:821:6: in function preprocess__ZN4core4iter5range101_$LT$impl$u20$core..iter..traits..iterator..Iterator$u20$for$u20$core..ops..range..Range$LT$A$GT$$GT$4next17h8f76b017d0dedd06E { i64, i64 } (ptr): Enzyme: Cannot deduce type of insertvalue ins   %6 = insertvalue { i64, i64 } %5, i64 %4, 1, !dbg !914 size: 8 TT: {}
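As a sanity check that does not depend on Enzyme, the three variants from the reproducer can be compared in plain Rust, together with a hand-written gradient. This is only a sketch: `dot_grad` is a hypothetical helper written for illustration (it is not generated by `#[autodiff]`), and it uses the fact that the gradient of a dot product with respect to `x` is `y` and vice versa, scaled by the seed.

```rust
const N: usize = 3;

// Iterator version (the one that still fails in debug builds with autodiff).
fn dot_iter(x: &[f64], y: &[f64]) -> f64 {
    x.iter().zip(y).map(|(x, y)| x * y).sum()
}

// Loop version (the "increments sum" variant).
fn dot_loop(x: &[f64], y: &[f64]) -> f64 {
    let mut sum = 0.0;
    for i in 0..N {
        sum += x[i] * y[i];
    }
    sum
}

// Fully unrolled version (reported to always work).
fn dot_unrolled(x: &[f64], y: &[f64]) -> f64 {
    x[0] * y[0] + x[1] * y[1] + x[2] * y[2]
}

// Hypothetical hand-written reverse-mode gradient, for comparison only:
// d(x.y)/dx = y and d(x.y)/dy = x, each scaled by the output seed.
fn dot_grad(x: &[f64; N], y: &[f64; N], seed: f64) -> ([f64; N], [f64; N]) {
    let mut bx = [0.0; N];
    let mut by = [0.0; N];
    for i in 0..N {
        bx[i] += seed * y[i];
        by[i] += seed * x[i];
    }
    (bx, by)
}

fn main() {
    // Same inputs as new_array::<N>(10) and new_array::<N>(100) in the reproducer.
    let x = [10.0, 11.0, 12.0];
    let y = [100.0, 101.0, 102.0];
    println!("iter: {}", dot_iter(&x, &y));
    println!("loop: {}", dot_loop(&x, &y));
    println!("unrolled: {}", dot_unrolled(&x, &y));
    let (bx, by) = dot_grad(&x, &y, 1.0);
    println!("bx: {:?}, by: {:?}", bx, by);
}
```

With a seed of 1.0, a correct `b_dot_local` should therefore accumulate `y` into `bx` and `x` into `by`.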

Meta

rustc --version --verbose:

rustc 1.77.0-nightly (63fa87211 2024-07-16)
binary: rustc
commit-hash: 63fa87211c25d89e585d76e2a15724a600eff903
commit-date: 2024-07-16
host: x86_64-unknown-linux-gnu
release: 1.77.0-nightly
LLVM version: 17.0.6

ZuseZ4 commented 1 month ago

@jedbrown Using the new ENZYME_OPT dbg features as described here (https://enzyme.mit.edu/index.fcgi/rust/Debugging.html#backend-crashes), I just got this MRE:

https://fwd.gymni.ch/7x2UOu

wsmoses commented 1 month ago

Illegal type analysis usually implies that either invalid metadata is being provided to Enzyme from this repo or one is differentiating a union with strict aliasing turned on (the default). It could also imply there is a bithack somewhere that Enzyme fails to understand.

My guess is that the first is most likely, if you want to check that first.

wsmoses commented 1 month ago

It could also imply a bad inductive type propagation rule, but I would recommend checking the inputs first, since otherwise it's “garbage in, garbage out”.

ZuseZ4 commented 1 month ago

For case two (the union): I don't want to use the Enzyme flag, since I assume it would prevent Enzyme from catching some type errors. Could you introduce a TA type for unions, which disables the relevant checks for one specific variable, and which I could create during lowering from Rust? Or wouldn't that help?

wsmoses commented 1 month ago

We may be able to design something, but I'm not sure it will be as simple as that. Essentially because we propagate across all dataflow knowing ...

Okay, here is a quick way we may be able to mark individual LLVM values as being unions: basically forbid type info from passing through them [just like how disabling strict aliasing works right now, by only propagating values which postdominate].

Maybe you could write a design doc / proposal and we can talk through the implications in a future Enzyme open design meeting?

wsmoses commented 1 month ago

Separately @ZuseZ4 something potentially useful here for debugging this.

If you turn on -enzyme-print-type we will print all of the type deductions for all values.

At the end of the error log here you see the invalid type combination.

Illegal updateAnalysis prev:{[-1]:Pointer, [-1,0]:Pointer} new: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer}
val:   %9 = alloca ptr, align 8 origin=  store ptr %16, ptr %9, align 8, !dbg !137

i.e. pointer-to-pointer versus pointer-to-int.

Maybe it would be helpful to add another debug tool here. We have the debug info, so we could at minimum say "hey, this is the source line the type error was deduced at" [which can already let us know more easily if it's a union].

But we can also look through the log and see how that int/pointer propagation gets into the variable, so the tool could yell "hey, this is path A that knows an int, through these variables" and "this is path B that knows a pointer". This is, for example, what I would do [and would also recommend you do] to debug this, so having a tool to automate it would be awesome and would make finding and fixing things much faster and more user-friendly!
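A first step toward such a tool could look like the sketch below. It only handles the final `Illegal updateAnalysis` line in the format shown in this thread and reports the offsets at which `prev` and `new` disagree. All function names are hypothetical, the matching is by exact offset key (it does not model `-1` acting as a wildcard), and tracing the dataflow paths would additionally require parsing the full `-enzyme-print-type` log, whose format is not shown here.

```rust
use std::collections::BTreeMap;

/// Return the text between the first `{` after `label` and its closing `}`.
fn braces_after<'a>(line: &'a str, label: &str) -> Option<&'a str> {
    let start = line.find(label)? + label.len();
    let open = start + line[start..].find('{')?;
    let close = open + line[open..].find('}')?;
    Some(&line[open + 1..close])
}

/// Parse a type-map body such as `[-1]:Pointer, [-1,0]:Integer`
/// into (offset path -> type name) entries.
fn parse_type_map(body: &str) -> BTreeMap<String, String> {
    let mut map = BTreeMap::new();
    let mut rest = body;
    while let Some(open) = rest.find('[') {
        let Some(close) = rest[open..].find("]:") else { break };
        let key = rest[open..open + close + 1].to_string();
        let after = &rest[open + close + 2..];
        // The type name runs until the next entry separator or the end.
        let end = after.find(", [").unwrap_or(after.len());
        map.insert(key, after[..end].trim().to_string());
        rest = &after[end..];
    }
    map
}

/// List (offset, previous type, new type) for every offset present in both
/// maps with different types. Returns an empty list if the line does not match.
fn conflicts(line: &str) -> Vec<(String, String, String)> {
    let (Some(prev), Some(new)) = (braces_after(line, "prev:"), braces_after(line, "new:")) else {
        return Vec::new();
    };
    let prev = parse_type_map(prev);
    let new = parse_type_map(new);
    prev.iter()
        .filter_map(|(key, p)| {
            let n = new.get(key)?;
            (n != p).then(|| (key.clone(), p.clone(), n.clone()))
        })
        .collect()
}

fn main() {
    let line = "Illegal updateAnalysis prev:{[-1]:Pointer, [-1,0]:Pointer} \
                new: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer}";
    for (offset, prev, new) in conflicts(line) {
        println!("offset {offset}: previously {prev}, now {new}");
    }
}
```

For the log line above, this reports a single conflict at offset `[-1,0]` (Pointer versus Integer), which matches the manual reading of the error.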

wsmoses commented 1 month ago

Also, fwiw, from these logs specifically the error appears to be in _ZN4core5slice29_$LT$impl$u20$$u5b$T$u5d$$GT$4iter17h16be0a2b33ef2db5E, which takes a ptr and an i64 as arguments. You mark the second one as an int. It gets transformed via inttoptr to be stored into a pointer, causing the inconsistency. I'd probably say to not mark this function with type trees for the int if that's the case.

ZuseZ4 commented 1 month ago

> Probably I'd say to not mark this function with type tree's for the int if that's the case

Thanks for the info, but that's the issue I have with TA for Enzyme. I don't emit these TT myself; I only annotate the outermost function getting differentiated and a few mem* calls. And even there I don't have the context information to ask "is this int also getting used as a pointer?". That would require more invasive rustc changes, which would probably get rejected (see e.g. https://rust-lang.zulipchat.com/#narrow/stream/435869-project-goals-2024h2/topic/sci-computing/near/452174200 for one rustc dev questioning whether we aren't already too invasive), and a change that takes context into consideration would make things much more complex. People probably also wouldn't like it if I went through every piece of Rust std functionality (like the slice implementations here) and decided to add tt to the lowering process here and there. So I can completely stop emitting tt for all integers everywhere, but I currently don't see an easy path for disabling tt only for this specific case, which is why I haven't done anything to solve this in the past. Any ideas?

I generally feel like int2ptr/ptr2int casts are more common in Rust. How much would we lose by flipping the Enzyme flag to always treat them the same? It would still be better than treating everything as a union. I assume that type confusions between e.g. float and ptr, or between int and float, are much rarer.

ZuseZ4 commented 1 month ago

> Maybe it would be helpful to add another debug tool here. We have the debug info so we could at minimum say "hey this is the source line the type error was deduced at" [which already can let us know if its a union more easily].

Yes, it's also already exposed, I have 3 ENZYME_PRINT<_TA/_AA> flags. Let me get it for this case.

wsmoses commented 1 month ago

> issue I have with TA for Enzyme. I don't emit these TT myself, I only annotate the outermost function getting differentiated and a few mem* calls. And even there I don't have the context information to ask "is this int also getting used as pointer", that

I don't know where this comes from in the Rust code, but I kind of assume changing an int to a pointer is unsafe Rust. There are probably a lot of reasonable thresholds for when to emit vs. not, but maybe unsafe code could have less TT emitted?
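One caveat for any threshold based on `unsafe`: in Rust, the `as` casts between raw pointers and `usize` compile outside an `unsafe` block, and only dereferencing the resulting raw pointer requires one. The sketch below illustrates this with a hypothetical helper, `first_via_address`; it assumes such casts typically lower to `ptrtoint` / `inttoptr` in unoptimized LLVM IR, which is the pattern that gives one value both an integer and a pointer interpretation.

```rust
/// Read the first element of a non-empty slice through an integer round trip.
fn first_via_address(xs: &[f64]) -> f64 {
    assert!(!xs.is_empty());

    // Both casts are allowed in safe code.
    let p: *const f64 = xs.as_ptr();
    let addr: usize = p as usize; // pointer -> integer
    let q: *const f64 = addr as *const f64; // integer -> pointer

    // SAFETY: `q` has the same address as `p`, which points at the first
    // element of a non-empty slice that is borrowed for this whole call.
    unsafe { *q }
}

fn main() {
    let x = [10.0, 11.0, 12.0];
    println!("{}", first_via_address(&x));
}
```

So the presence or absence of an `unsafe` block around a cast is not, by itself, a reliable signal that an integer may later be used as a pointer.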

wsmoses commented 1 month ago

> > Maybe it would be helpful to add another debug tool here. We have the debug info so we could at minimum say "hey this is the source line the type error was deduced at" [which already can let us know if its a union more easily].

> Yes, it's also already exposed, I have 3 ENZYME_PRINT<_TA/_AA> flags. Let me get it for this case.

Sorry, I didn't mean to say we need it for this case specifically [it was resolved by looking at the debug log alone]. Rather, it might be useful to have a tool which takes in the logs generated by ENZYME_PRINT_TA / -enzyme-print-type and then tells you the dataflow which caused the error.

ZuseZ4 commented 1 month ago

> > issue I have with TA for Enzyme. I don't emit these TT myself, I only annotate the outermost function getting differentiated and a few mem* calls. And even there I don't have the context information to ask "is this int also getting used as pointer", that

> I don't know where this comes from in the Rust code, but I kind of assume changing an int to a pointer is unsafe Rust. There are probably a lot of reasonable thresholds for when to emit vs. not, but maybe unsafe code could have less TT emitted?

Unsafe code mainly exists at a higher-level IR (THIR), not in the MIR which we lower into LLVM IR (and where we add tt). But if it's really only the i2p/p2i casts, then maybe we can add extra tt info to these casts, carrying whatever information Enzyme's TA needs for this specific location (i.e. a note that the type is allowed to change here, or an instruction to overwrite what was previously believed to be an int with a ptr type). The benefit would be that I can still lower all MIR types normally, and we would still catch tt mismatch errors if they happen at any location other than such a specific i2p/p2i cast.

I can try to find a rustc dev who knows more here.

ZuseZ4 commented 1 month ago

@jedbrown with the latest updates in enzyme core, the second version (increments sum) now works. The first one still gives an incorrect TA update though.

jedbrown commented 1 month ago

Should I add that version to the test suite?