EnzymeAD / rust

A Rust fork working towards Enzyme integration
https://www.rust-lang.org
Other
53 stars 7 forks source link

Enzyme creates endless loop #38

Closed ZuseZ4 closed 10 months ago

ZuseZ4 commented 10 months ago
use autodiff::autodiff;

use std::io;

/// Test struct whose layout is what the reproducer exercises.
///
/// When passed by reference (as `sin` takes it), the LLVM IR in this issue shows
/// the layout `{f32, i16, i16}`, 8 bytes in total. rustc has reordered the
/// fields relative to source order: `a` is at offset 0, `c1` at offset 4 and
/// `c2` at offset 6.
///
/// NOTE(review): the original note states that `Foo` is passed as a single
/// `i64` when taken by value. The IR shown here does not cover that case, so
/// confirm it before relying on it.
struct Foo {
    /// Lower bound in the `x.c1 < x.c2` assertion inside `sin`. As an integer
    /// field it receives no gradient from `cos`.
    c1: i16,
    /// The only floating-point field; `sin` returns `f32::sin(a)`.
    a: f32,
    /// Upper bound in the `x.c1 < x.c2` assertion. `main` reads it at runtime.
    c2: i16,
}

/// Returns `sin(x.a)` once `x.c1 < x.c2` has been checked.
///
/// `#[autodiff]` derives the reverse-mode function `cos` from this one. As it
/// is called from `main`, `cos(&primal, &mut shadow, seed)` takes:
/// - the primal `Foo`,
/// - a shadow `Foo`, into whose `a` field `seed * cos(primal.a)` is accumulated,
/// - the `f32` seed for the return value.
///
/// # Panics
/// Panics with `assertion failed: x.c1 < x.c2` when `x.c1 >= x.c2`.
#[autodiff(cos, Reverse, Active, Duplicated)]
fn sin(x: &Foo) -> f32 {
    // Keep this exact expression: its tokens become the panic message.
    assert!(x.c1 < x.c2);
    let angle: f32 = x.a;
    angle.sin()
}

/// Driver for the `sin` / `cos` reproducer.
///
/// `c2` is read at runtime, either from the first command-line argument or, if
/// there is none, from one line of stdin. This keeps it from being a
/// compile-time constant, presumably so the optimizer cannot fold away the
/// `x.c1 < x.c2` assertion in `sin`.
///
/// The program prints the gradient that `cos` accumulates into `df_dfoo.a`,
/// alongside the analytic derivative `f32::cos(foo.a)` for comparison.
///
/// Without an argument the program blocks on stdin. A prompt is printed to
/// stderr so that this is not mistaken for a hang, which is what issue #38
/// turned out to be.
///
/// # Panics
/// - If stdin cannot be read, or the input is not a valid `i16`.
/// - Inside `sin`, if `c2` is not greater than `c1 = 4`.
fn main() {
    let raw = match std::env::args().nth(1) {
        Some(arg) => arg,
        None => {
            // stderr is unbuffered, so the prompt appears before we block.
            eprint!("enter c2 (an i16 greater than 4): ");
            let mut s = String::new();
            io::stdin()
                .read_line(&mut s)
                .expect("failed to read c2 from stdin");
            s
        }
    };
    let c2 = raw
        .trim()
        .parse::<i16>()
        .expect("c2 must be an integer in the i16 range");
    dbg!(c2);

    let foo = Foo { c1: 4, a: 3.14, c2 };
    // Shadow of `foo`. `cos` accumulates into `df_dfoo.a`, so it must start at
    // 0.0. The integer fields are copied from the primal and are not written.
    let mut df_dfoo = Foo { c1: 4, a: 0.0, c2 };

    dbg!(df_dfoo.a);
    dbg!(cos(&foo, &mut df_dfoo, 1.0));
    dbg!(df_dfoo.a);
    dbg!(f32::cos(foo.a));
}
analyzing function preprocess__ZN6struct3sin17hc3800a79908d474dE
 + knowndata: ptr %0 : {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer} - {}
 + retdata: {[-1]:Float@float}
updating analysis of val: ptr %0 current: {} new {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer} from ptr %0 Changed=1 legal=1
updating analysis of val: ptr %0 current: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer} new {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer} from ptr %0 Changed=0 legal=1
updating analysis of val:   %9 = tail call float @llvm.sin.f32(float %8) #102 current: {} new {[-1]:Float@float} from   %9 = tail call float @llvm.sin.f32(float %8) #102 Changed=1 legal=1
updating analysis of val:   %9 = tail call float @llvm.sin.f32(float %8) #102 current: {[-1]:Float@float} new {[-1]:Float@float} from   %9 = tail call float @llvm.sin.f32(float %8) #102 Changed=0 legal=1
updating analysis of val:   %2 = getelementptr inbounds %2, ptr %0, i64 0, i32 1 current: {} new {[-1]:Pointer} from   %2 = getelementptr inbounds %2, ptr %0, i64 0, i32 1 Changed=1 legal=1
updating analysis of val:   %2 = getelementptr inbounds %2, ptr %0, i64 0, i32 1 current: {[-1]:Pointer} new {[-1]:Pointer} from   %2 = getelementptr inbounds %2, ptr %0, i64 0, i32 1 Changed=0 legal=1
updating analysis of val: ptr %0 current: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer} new {[-1]:Pointer} from   %2 = getelementptr inbounds %2, ptr %0, i64 0, i32 1 Changed=0 legal=1
updating analysis of val:   %2 = getelementptr inbounds %2, ptr %0, i64 0, i32 1 current: {[-1]:Pointer} new {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer} from   %2 = getelementptr inbounds %2, ptr %0, i64 0, i32 1 Changed=1 legal=1
updating analysis of val: ptr %0 current: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer} new {[-1]:Pointer} from   %2 = getelementptr inbounds %2, ptr %0, i64 0, i32 1 Changed=0 legal=1
updating analysis of val:   %2 = getelementptr inbounds %2, ptr %0, i64 0, i32 1 current: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer} new {[-1]:Pointer} from   %3 = load i16, ptr %2, align 4, !noundef !4 Changed=0 legal=1
updating analysis of val:   %3 = load i16, ptr %2, align 4, !noundef !4 current: {} new {[-1]:Integer} from   %3 = load i16, ptr %2, align 4, !noundef !4 Changed=1 legal=1
updating analysis of val:   %4 = getelementptr inbounds %2, ptr %0, i64 0, i32 2 current: {} new {[-1]:Pointer} from   %4 = getelementptr inbounds %2, ptr %0, i64 0, i32 2 Changed=1 legal=1
updating analysis of val:   %4 = getelementptr inbounds %2, ptr %0, i64 0, i32 2 current: {[-1]:Pointer} new {[-1]:Pointer} from   %4 = getelementptr inbounds %2, ptr %0, i64 0, i32 2 Changed=0 legal=1
updating analysis of val: ptr %0 current: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer} new {[-1]:Pointer} from   %4 = getelementptr inbounds %2, ptr %0, i64 0, i32 2 Changed=0 legal=1
updating analysis of val:   %4 = getelementptr inbounds %2, ptr %0, i64 0, i32 2 current: {[-1]:Pointer} new {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer} from   %4 = getelementptr inbounds %2, ptr %0, i64 0, i32 2 Changed=1 legal=1
updating analysis of val: ptr %0 current: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer} new {[-1]:Pointer} from   %4 = getelementptr inbounds %2, ptr %0, i64 0, i32 2 Changed=0 legal=1
updating analysis of val:   %4 = getelementptr inbounds %2, ptr %0, i64 0, i32 2 current: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer} new {[-1]:Pointer} from   %5 = load i16, ptr %4, align 2, !noundef !4 Changed=0 legal=1
updating analysis of val:   %5 = load i16, ptr %4, align 2, !noundef !4 current: {} new {[-1]:Integer} from   %5 = load i16, ptr %4, align 2, !noundef !4 Changed=1 legal=1
updating analysis of val:   %6 = icmp slt i16 %3, %5 current: {} new {[-1]:Integer} from   %6 = icmp slt i16 %3, %5 Changed=1 legal=1
updating analysis of val:   %3 = load i16, ptr %2, align 4, !noundef !4 current: {[-1]:Integer} new {[-1]:Integer} from   %6 = icmp slt i16 %3, %5 Changed=0 legal=1
updating analysis of val:   %5 = load i16, ptr %4, align 2, !noundef !4 current: {[-1]:Integer} new {[-1]:Integer} from   %6 = icmp slt i16 %3, %5 Changed=0 legal=1
 skipping update into ptr %0 of {[-1]:Pointer} from   %8 = load float, ptr %0, align 4, !noundef !4
updating analysis of val:   %8 = load float, ptr %0, align 4, !noundef !4 current: {} new {[-1]:Float@float} from   %8 = load float, ptr %0, align 4, !noundef !4 Changed=1 legal=1
updating analysis of val:   %9 = tail call float @llvm.sin.f32(float %8) #102 current: {[-1]:Float@float} new {[-1]:Float@float} from   %9 = tail call float @llvm.sin.f32(float %8) #102 Changed=0 legal=1
updating analysis of val:   %8 = load float, ptr %0, align 4, !noundef !4 current: {[-1]:Float@float} new {[-1]:Float@float} from   %9 = tail call float @llvm.sin.f32(float %8) #102 Changed=0 legal=1
updating analysis of val:   %2 = getelementptr inbounds %2, ptr %0, i64 0, i32 1 current: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer} new {[-1]:Pointer} from   %2 = getelementptr inbounds %2, ptr %0, i64 0, i32 1 Changed=0 legal=1
updating analysis of val:   %2 = getelementptr inbounds %2, ptr %0, i64 0, i32 1 current: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer} new {[-1]:Pointer} from   %2 = getelementptr inbounds %2, ptr %0, i64 0, i32 1 Changed=0 legal=1
updating analysis of val: ptr %0 current: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer} new {[-1]:Pointer} from   %2 = getelementptr inbounds %2, ptr %0, i64 0, i32 1 Changed=0 legal=1
updating analysis of val:   %2 = getelementptr inbounds %2, ptr %0, i64 0, i32 1 current: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer} new {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer} from   %2 = getelementptr inbounds %2, ptr %0, i64 0, i32 1 Changed=0 legal=1
updating analysis of val: ptr %0 current: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer} new {[-1]:Pointer, [-1,4]:Integer, [-1,5]:Integer} from   %2 = getelementptr inbounds %2, ptr %0, i64 0, i32 1 Changed=0 legal=1
updating analysis of val:   %4 = getelementptr inbounds %2, ptr %0, i64 0, i32 2 current: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer} new {[-1]:Pointer} from   %4 = getelementptr inbounds %2, ptr %0, i64 0, i32 2 Changed=0 legal=1
updating analysis of val:   %4 = getelementptr inbounds %2, ptr %0, i64 0, i32 2 current: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer} new {[-1]:Pointer} from   %4 = getelementptr inbounds %2, ptr %0, i64 0, i32 2 Changed=0 legal=1
updating analysis of val: ptr %0 current: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer} new {[-1]:Pointer} from   %4 = getelementptr inbounds %2, ptr %0, i64 0, i32 2 Changed=0 legal=1
updating analysis of val:   %4 = getelementptr inbounds %2, ptr %0, i64 0, i32 2 current: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer} new {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer} from   %4 = getelementptr inbounds %2, ptr %0, i64 0, i32 2 Changed=0 legal=1
updating analysis of val: ptr %0 current: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer} new {[-1]:Pointer, [-1,6]:Integer, [-1,7]:Integer} from   %4 = getelementptr inbounds %2, ptr %0, i64 0, i32 2 Changed=0 legal=1
updating analysis of val:   %2 = getelementptr inbounds %2, ptr %0, i64 0, i32 1 current: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer} new {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer} from   %3 = load i16, ptr %2, align 4, !noundef !4 Changed=0 legal=1
updating analysis of val:   %3 = load i16, ptr %2, align 4, !noundef !4 current: {[-1]:Integer} new {[-1]:Integer} from   %3 = load i16, ptr %2, align 4, !noundef !4 Changed=0 legal=1
updating analysis of val:   %4 = getelementptr inbounds %2, ptr %0, i64 0, i32 2 current: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer} new {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer} from   %5 = load i16, ptr %4, align 2, !noundef !4 Changed=0 legal=1
updating analysis of val:   %5 = load i16, ptr %4, align 2, !noundef !4 current: {[-1]:Integer} new {[-1]:Integer} from   %5 = load i16, ptr %4, align 2, !noundef !4 Changed=0 legal=1
    Finished release [optimized] target(s) in 0.28s
ZuseZ4 commented 10 months ago
After simplification:
; Function Attrs: mustprogress noinline nonlazybind sanitize_hwaddress willreturn uwtable
; Primal `sin` as seen by Enzyme after preprocessing. %0 points at the 8-byte
; Foo: field 0 is `a` (f32); fields 1 and 2 are `c1` and `c2` (i16).
define internal noundef float @preprocess__ZN6struct3sin17hec82b97809b00bbeE(ptr noalias nocapture noundef readonly align 4 dereferenceable(8) %0) unnamed_addr #100 {
  ; assert!(x.c1 < x.c2): load c1 (field 1) and c2 (field 2), signed compare.
  %2 = getelementptr inbounds %2, ptr %0, i64 0, i32 1
  %3 = load i16, ptr %2, align 4, !noundef !4
  %4 = getelementptr inbounds %2, ptr %0, i64 0, i32 2
  %5 = load i16, ptr %4, align 2, !noundef !4
  %6 = icmp slt i16 %3, %5
  br i1 %6, label %7, label %10

7:                                                ; preds = %1
  ; Assertion held: return sin(a), with `a` loaded from offset 0 of %0.
  %8 = load float, ptr %0, align 4, !noundef !4
  %9 = tail call float @llvm.sin.f32(float %8) #101
  ret float %9

10:                                               ; preds = %1
  ; Assertion failed: panic. The 29-byte message is presumably
  ; "assertion failed: x.c1 < x.c2" (the global's contents are not shown).
  tail call void @_ZN4core9panicking5panic17h9a6a4d6bf7daca76E(ptr noalias noundef nonnull readonly align 1 @anon.ecad6efadb6fdfe82af63b640f604887.17, i64 noundef 29, ptr noalias noundef nonnull readonly align 8 dereferenceable(24) @anon.ecad6efadb6fdfe82af63b640f604887.18) #102
  unreachable
}

; Function Attrs: mustprogress noinline nonlazybind sanitize_hwaddress willreturn memory(readwrite) uwtable
; Reverse-mode derivative of `sin` generated by Enzyme.
;   %0 - primal Foo (read-only)
;   %1 - shadow Foo; the gradient for `a` is accumulated at offset 0
;   %2 - seed for the f32 return value (1.0 in the Rust `main`)
define internal void @diffe_ZN6struct3sin17hec82b97809b00bbeE(ptr noalias nocapture noundef readonly align 4 dereferenceable(8) %0, ptr nocapture align 4 %1, float %2) unnamed_addr #101 {
  ; Re-evaluate the primal assertion c1 < c2 before differentiating.
  %4 = getelementptr inbounds %2, ptr %0, i64 0, i32 1
  %5 = load i16, ptr %4, align 4, !alias.scope !48573, !noalias !48576, !noundef !4
  %6 = getelementptr inbounds %2, ptr %0, i64 0, i32 2
  %7 = load i16, ptr %6, align 2, !alias.scope !48573, !noalias !48576, !noundef !4
  %8 = icmp slt i16 %5, %7
  br i1 %8, label %9, label %15

9:                                                ; preds = %3
  ; d/da sin(a) = cos(a), scaled by the seed, then added into shadow.a
  ; (load / fadd / store, so the shadow accumulates rather than overwrites).
  %10 = load float, ptr %0, align 4, !alias.scope !48573, !noalias !48576, !noundef !4
  %11 = call fast float @llvm.cos.f32(float %10)
  %12 = fmul fast float %2, %11
  %13 = load float, ptr %1, align 4, !alias.scope !48576, !noalias !48573
  %14 = fadd fast float %13, %12
  store float %14, ptr %1, align 4, !alias.scope !48576, !noalias !48573
  ret void

15:                                               ; preds = %3
  ; Assertion failed: same panic call as in the primal.
  tail call void @_ZN4core9panicking5panic17h9a6a4d6bf7daca76E(ptr noalias noundef nonnull readonly align 1 @anon.ecad6efadb6fdfe82af63b640f604887.17, i64 noundef 29, ptr noalias noundef nonnull readonly align 8 dereferenceable(24) @anon.ecad6efadb6fdfe82af63b640f604887.18) #102
  unreachable
}
ZuseZ4 commented 10 months ago

Odd?

wsmoses commented 10 months ago

`main` reads from stdin (as discussed on the call), so the program was blocked waiting for input rather than looping. When input is provided, it runs fine.