EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Illegal type analysis error with `setindex!` + `ifelse` + mixed float types #809

Closed ToucheSir closed 6 months ago

ToucheSir commented 1 year ago

I found this when trying to diff through https://github.com/FluxML/NNlib.jl/blob/acf87f5316e7579ac1e7eb16a278f43a9ca435dc/src/softmax.jl#L115.

MWE:

using Enzyme

function f(out, x)
    out[1] = ifelse(isequal(x[1], Inf), -Inf, x[1])
    out[1]
end

x = ones(Float32, 1)
@show f(copy(x), copy(x))

dx = zero(x)
Enzyme.autodiff(
    Reverse,
    f,
    Active,
    Duplicated(copy(x), copy(dx)),
    Duplicated(x, dx)
)
@show x dx

Looking at the generated IR, it appears Julia is union-splitting the result of the `ifelse` and only truncating if it turns out to be a Float64, which seems a little unnecessary since LLVM ends up promoting it to a double unconditionally anyhow. Strange codegen aside, this specific example should be relatively easy to fix on the NNlib side, but I suspect there are many more examples lurking out there.
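
The union can also be seen without Enzyme by asking inference about the scalar expression on its own. This is only a sketch: `select_unstable` is a hypothetical helper (not part of the MWE or NNlib) that isolates the `ifelse` call, and it assumes a reasonably recent Julia where `only` is available.

```julia
# Hypothetical helper isolating the `ifelse` expression from the MWE.
# `-Inf` is a Float64 literal, so for a Float32 `v` the two branches differ in type.
select_unstable(v) = ifelse(isequal(v, Inf), -Inf, v)

# Ask inference for the return type when the argument is a Float32.
inferred = only(Base.return_types(select_unstable, (Float32,)))
@show inferred

# At runtime the concrete type depends on which branch is selected.
@show typeof(select_unstable(1.0f0))  # value taken from `v`
@show typeof(select_unstable(Inf32))  # value taken from the `-Inf` literal
```

The result only becomes a Float32 again when it is stored into `out`, which is where the conditional truncation in the IR comes from.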

Error: ``` ERROR: LoadError: Enzyme compilation failed due to illegal type analysis. Current scope: ; Function Attrs: mustprogress willreturn define float @preprocess_julia_f_1435_inner.1({} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %0, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %1) local_unnamed_addr #6 !dbg !68 { entry: %2 = alloca float, align 4 %.0.sroa_cast4 = bitcast float* %2 to i8* call void @llvm.lifetime.start.p0i8(i64 4, i8* %.0.sroa_cast4) %3 = call {}*** @julia.get_pgcstack() #7 %4 = bitcast {} addrspace(10)* %1 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !69 %5 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %4 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !69 %6 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %5, i64 0, i32 1, !dbg !69 %7 = load i64, i64 addrspace(11)* %6, align 8, !dbg !69, !tbaa !12, !range !17, !alias.scope !18, !noalias !21 %.not = icmp eq i64 %7, 0, !dbg !69 br i1 %.not, label %oob.i, label %idxend2.i, !dbg !69 L16.i: ; preds = %idxend2.i %8 = bitcast {} addrspace(10)* %0 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !72 %9 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %8 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !72 %10 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %9, i64 0, i32 1, !dbg !72 %11 = load i64, i64 addrspace(11)* %10, align 8, !dbg !72, !tbaa !12, !range !17, !alias.scope !18, !noalias !21 %.not10 = icmp eq i64 %11, 0, !dbg !72 br i1 %.not10, label %oob3.i, label %idxend4.i, !dbg !72 L21.i: ; preds = %idxend2.i %12 = bitcast {} addrspace(10)* %0 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !72 %13 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* 
%12 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !72 %14 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %13, i64 0, i32 1, !dbg !72 %15 = load i64, i64 addrspace(11)* %14, align 8, !dbg !72, !tbaa !12, !range !17, !alias.scope !18, !noalias !21 %.not13 = icmp eq i64 %15, 0, !dbg !72 br i1 %.not13, label %oob7.i, label %idxend8.i, !dbg !72 oob.i: ; preds = %entry %16 = alloca i64, align 8, !dbg !69 store i64 1, i64* %16, align 8, !dbg !69, !noalias !73 %17 = addrspacecast {} addrspace(10)* %1 to {} addrspace(12)*, !dbg !69 call void @ijl_bounds_error_ints({} addrspace(12)* noundef %17, i64* noundef nonnull align 8 %16, i64 noundef 1) #8, !dbg !69 unreachable, !dbg !69 idxend2.i: ; preds = %entry %18 = bitcast {} addrspace(10)* %1 to float addrspace(13)* addrspace(10)*, !dbg !69 %19 = addrspacecast float addrspace(13)* addrspace(10)* %18 to float addrspace(13)* addrspace(11)*, !dbg !69 %20 = load float addrspace(13)*, float addrspace(13)* addrspace(11)* %19, align 16, !dbg !69, !tbaa !32, !alias.scope !76, !noalias !21, !nonnull !6 %21 = load float, float addrspace(13)* %20, align 4, !dbg !69, !tbaa !35, !alias.scope !38, !noalias !39 %22 = bitcast float %21 to i32, !dbg !77 %23 = icmp slt i32 %22, 0, !dbg !80 %24 = fcmp une float %21, 0x7FF0000000000000, !dbg !81 %25 = or i1 %24, %23, !dbg !83 store float %21, float* %2, align 4, !dbg !83, !noalias !73 %.0.sroa_cast3 = addrspacecast float* %2 to double addrspace(11)*, !dbg !83 %26 = select i1 %25, double addrspace(11)* %.0.sroa_cast3, double addrspace(11)* addrspacecast (double* @_j_const1 to double addrspace(11)*), !dbg !83 br i1 %25, label %L16.i, label %L21.i, !dbg !84 oob3.i: ; preds = %L16.i %27 = alloca i64, align 8, !dbg !72 store i64 1, i64* %27, align 8, !dbg !72, !noalias !73 %28 = addrspacecast {} addrspace(10)* %0 to {} addrspace(12)*, !dbg !72 call void @ijl_bounds_error_ints({} addrspace(12)* noundef %28, 
i64* noundef nonnull align 8 %27, i64 noundef 1) #8, !dbg !72 unreachable, !dbg !72 idxend4.i: ; preds = %L16.i %29 = bitcast {} addrspace(10)* %0 to float addrspace(13)* addrspace(10)*, !dbg !72 %30 = addrspacecast float addrspace(13)* addrspace(10)* %29 to float addrspace(13)* addrspace(11)*, !dbg !72 %31 = load float addrspace(13)*, float addrspace(13)* addrspace(11)* %30, align 16, !dbg !72, !tbaa !32, !alias.scope !76, !noalias !21, !nonnull !6 %32 = bitcast double addrspace(11)* %26 to float addrspace(11)*, !dbg !72 %33 = load float, float addrspace(11)* %32, align 4, !dbg !72, !tbaa !58 store float %33, float addrspace(13)* %31, align 4, !dbg !72, !tbaa !35, !alias.scope !38, !noalias !85 br label %julia_f_1435_inner.exit, !dbg !84 oob7.i: ; preds = %L21.i %34 = alloca i64, align 8, !dbg !72 store i64 1, i64* %34, align 8, !dbg !72, !noalias !73 %35 = addrspacecast {} addrspace(10)* %0 to {} addrspace(12)*, !dbg !72 call void @ijl_bounds_error_ints({} addrspace(12)* noundef %35, i64* noundef nonnull align 8 %34, i64 noundef 1) #8, !dbg !72 unreachable, !dbg !72 idxend8.i: ; preds = %L21.i %36 = load double, double addrspace(11)* %26, align 4, !dbg !86, !tbaa !58 %37 = fptrunc double %36 to float, !dbg !86 %38 = bitcast {} addrspace(10)* %0 to float addrspace(13)* addrspace(10)*, !dbg !72 %39 = addrspacecast float addrspace(13)* addrspace(10)* %38 to float addrspace(13)* addrspace(11)*, !dbg !72 %40 = load float addrspace(13)*, float addrspace(13)* addrspace(11)* %39, align 16, !dbg !72, !tbaa !32, !alias.scope !76, !noalias !21, !nonnull !6 store float %37, float addrspace(13)* %40, align 4, !dbg !72, !tbaa !35, !alias.scope !38, !noalias !85 br label %julia_f_1435_inner.exit, !dbg !84 julia_f_1435_inner.exit: ; preds = %idxend8.i, %idxend4.i %41 = phi float [ %37, %idxend8.i ], [ %33, %idxend4.i ] %.0.sroa_cast5 = bitcast float* %2 to i8*, !dbg !89 call void @llvm.lifetime.end.p0i8(i64 4, i8* %.0.sroa_cast5), !dbg !89 ret float %41, !dbg !90 } Type analysis 
state: %24 = fcmp une float %21, 0x7FF0000000000000, !dbg !51: {[-1]:Integer}, intvals: {} %25 = or i1 %24, %23, !dbg !55: {[-1]:Integer}, intvals: {} %6 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %5, i64 0, i32 1, !dbg !7: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer}, intvals: {} %14 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %13, i64 0, i32 1, !dbg !26: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer}, intvals: {} %5 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %4 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !7: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {} %7 = load i64, i64 addrspace(11)* %6, align 8, !dbg !7, !tbaa !12, !range !17, !alias.scope !18, !noalias !21: {[-1]:Integer}, intvals: {} %32 = bitcast double addrspace(11)* %26 to float addrspace(11)*, !dbg !26: {[-1]:Pointer, [-1,0]:Float@float}, intvals: {} %33 = load float, float addrspace(11)* %32, align 4, !dbg !26, !tbaa !58: {[-1]:Float@float}, intvals: {} float 0x7FF0000000000000: {[-1]:Float@float}, intvals: {} %29 = bitcast {} 
addrspace(10)* %0 to float addrspace(13)* addrspace(10)*, !dbg !26: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {} %30 = addrspacecast float addrspace(13)* addrspace(10)* %29 to float addrspace(13)* addrspace(11)*, !dbg !26: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {} %31 = load float addrspace(13)*, float addrspace(13)* addrspace(11)* %30, align 16, !dbg !26, !tbaa !32, !alias.scope !34, !noalias !21, !nonnull !6: {[-1]:Pointer, [-1,-1]:Float@float}, intvals: {} %26 = select i1 %25, double addrspace(11)* %.0.sroa_cast3, double addrspace(11)* addrspacecast (double* @_j_const1 to double addrspace(11)*), !dbg !55: {[-1]:Pointer, [-1,0]:Float@float}, intvals: {} %39 = addrspacecast float addrspace(13)* addrspace(10)* %38 to float addrspace(13)* addrspace(11)*, !dbg !26: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, 
[-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {} %40 = load float addrspace(13)*, float addrspace(13)* addrspace(11)* %39, align 16, !dbg !26, !tbaa !32, !alias.scope !34, !noalias !21, !nonnull !6: {[-1]:Pointer, [-1,-1]:Float@float}, intvals: {} %4 = bitcast {} addrspace(10)* %1 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !7: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {} %36 = load double, double addrspace(11)* %26, align 4, !dbg !60, !tbaa !58: {[-1]:Float@double}, intvals: {} %37 = fptrunc double %36 to float, !dbg !60: {[-1]:Float@float}, intvals: {} %38 = bitcast {} addrspace(10)* %0 to float addrspace(13)* addrspace(10)*, !dbg !26: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, 
[-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {} double 0xFFF0000000000000: {[-1]:Float@double}, intvals: {} i64 0: {[-1]:Anything}, intvals: {0,} %15 = load i64, i64 addrspace(11)* %14, align 8, !dbg !26, !tbaa !12, !range !17, !alias.scope !18, !noalias !21: {[-1]:Integer}, intvals: {} %.not13 = icmp eq i64 %15, 0, !dbg !26: {[-1]:Integer}, intvals: {} %9 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %8 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !26: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {} %11 = load i64, i64 addrspace(11)* %10, align 8, !dbg !26, !tbaa !12, !range !17, !alias.scope !18, !noalias !21: {[-1]:Integer}, intvals: {} %.not10 = icmp eq i64 %11, 0, !dbg !26: {[-1]:Integer}, intvals: {} %12 = bitcast {} addrspace(10)* %0 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !26: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, 
[-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {} %13 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %12 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !26: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {} %2 = alloca float, align 4: {[-1]:Pointer, [-1,-1]:Float@float}, intvals: {} %.0.sroa_cast5 = bitcast float* %2 to i8*, !dbg !66: {[-1]:Pointer, [-1,0]:Float@float}, intvals: {} {} addrspace(10)* %0: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, 
[-1,39]:Integer, [-1,40]:Integer}, intvals: {} {} addrspace(10)* %1: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {} %21 = load float, float addrspace(13)* %20, align 4, !dbg !7, !tbaa !35, !alias.scope !38, !noalias !39: {[-1]:Float@float}, intvals: {} %22 = bitcast float %21 to i32, !dbg !40: {[-1]:Float@float}, intvals: {} %23 = icmp slt i32 %22, 0, !dbg !48: {[-1]:Integer}, intvals: {} %41 = phi float [ %37, %idxend8.i ], [ %33, %idxend4.i ]: {[-1]:Float@float}, intvals: {} %.0.sroa_cast3 = addrspacecast float* %2 to double addrspace(11)*, !dbg !55: {[-1]:Pointer, [-1,-1]:Float@float}, intvals: {} %.0.sroa_cast4 = bitcast float* %2 to i8*: {[-1]:Pointer, [-1,-1]:Float@float}, intvals: {} %10 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %9, i64 0, i32 1, !dbg !26: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer}, intvals: {} %.not = icmp eq i64 %7, 0, !dbg !7: {[-1]:Integer}, intvals: {} %8 = bitcast {} addrspace(10)* %0 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !26: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, 
[-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {} i32 0: {[-1]:Anything}, intvals: {0,} %18 = bitcast {} addrspace(10)* %1 to float addrspace(13)* addrspace(10)*, !dbg !7: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {} %19 = addrspacecast float addrspace(13)* addrspace(10)* %18 to float addrspace(13)* addrspace(11)*, !dbg !7: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {} %20 = load float addrspace(13)*, float addrspace(13)* addrspace(11)* %19, align 16, !dbg !7, !tbaa !32, !alias.scope !34, 
!noalias !21, !nonnull !6: {[-1]:Pointer, [-1,-1]:Float@float}, intvals: {}
@_j_const1 = private unnamed_addr constant double 0xFFF0000000000000: {[-1]:Pointer, [-1,-1]:Float@double}, intvals: {}
double addrspace(11)* addrspacecast (double* @_j_const1 to double addrspace(11)*): {[-1]:Pointer, [-1,-1]:Float@double}, intvals: {}
Illegal updateAnalysis prev:{[-1]:Pointer, [-1,0]:Float@float} new: {[-1]:Pointer, [-1,0]:Float@double}
val: %26 = select i1 %25, double addrspace(11)* %.0.sroa_cast3, double addrspace(11)* addrspacecast (double* @_j_const1 to double addrspace(11)*), !dbg !55 origin= %36 = load double, double addrspace(11)* %26, align 4, !dbg !60, !tbaa !58
Caused by:
Stacktrace:
 [1] ifelse
   @ ./essentials.jl:575
 [2] f
   @ .jl:4
 [3] f
   @ .jl:0
Stacktrace:
 [1] julia_error(cstr::Cstring, val::Ptr{LLVM.API.LLVMOpaqueValue}, errtype::Enzyme.API.ErrorType, data::Ptr{Nothing})
   @ Enzyme.Compiler ~/.julia/packages/Enzyme/YBQJk/src/compiler.jl:5330
 [2] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{Enzyme.API.CDIFFE_TYPE}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{Nothing}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{Bool}, augmented::Ptr{Nothing}, atomicAdd::Bool)
   @ Enzyme.API ~/.julia/packages/Enzyme/YBQJk/src/api.jl:124
 [3] enzyme!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{Bool, Bool, Bool}, returnPrimal::Bool, jlrules::Vector{String}, expectedTapeType::Type)
   @ Enzyme.Compiler ~/.julia/packages/Enzyme/YBQJk/src/compiler.jl:6927
 [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, ctx::LLVM.ThreadSafeContext, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
   @ Enzyme.Compiler ~/.julia/packages/Enzyme/YBQJk/src/compiler.jl:8177
 [5] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, ctx::Nothing, postopt::Bool)
   @ Enzyme.Compiler ~/.julia/packages/Enzyme/YBQJk/src/compiler.jl:8690
 [6] _thunk
   @ ~/.julia/packages/Enzyme/YBQJk/src/compiler.jl:8687 [inlined]
 [7] cached_compilation
   @ ~/.julia/packages/Enzyme/YBQJk/src/compiler.jl:8725 [inlined]
 [8] #s287#191
   @ ~/.julia/packages/Enzyme/YBQJk/src/compiler.jl:8783 [inlined]
 [9] var"#s287#191"(FA::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, ReturnPrimal::Any, ShadowInit::Any, World::Any, ::Any, ::Any, ::Any, ::Any, tt::Any, ::Any, ::Any, ::Any, ::Any, ::Any)
   @ Enzyme.Compiler ./none:0
 [10] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
   @ Core ./boot.jl:602
 [11] thunk(::Val{0x00000000000082c2}, ::Type{Const{typeof(f)}}, ::Type{Active}, tt::Type{Tuple{Duplicated{Vector{Float32}}, Duplicated{Vector{Float32}}}}, ::Val{Enzyme.API.DEM_ReverseModeCombined}, ::Val{1}, ::Val{(false, false, false)}, ::Val{false})
   @ Enzyme.Compiler ~/.julia/packages/Enzyme/YBQJk/src/compiler.jl:8742
 [12] autodiff(::EnzymeCore.ReverseMode{false}, ::Const{typeof(f)}, ::Type{Active}, ::Duplicated{Vector{Float32}}, ::Vararg{Duplicated{Vector{Float32}}})
   @ Enzyme ~/.julia/packages/Enzyme/YBQJk/src/Enzyme.jl:199
 [13] autodiff(::EnzymeCore.ReverseMode{false}, ::typeof(f), ::Type, ::Duplicated{Vector{Float32}}, ::Vararg{Duplicated{Vector{Float32}}})
   @ Enzyme ~/.julia/packages/Enzyme/YBQJk/src/Enzyme.jl:214
 [14] top-level scope
   @ .jl:12
in expression starting at .jl:12 ```

wsmoses commented 1 year ago

Yeah, for the immediate future at least, you should go fix the union in NNlib here (which would also be a generic performance benefit even without AD).
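
One possible shape for such a fix, sketched here under assumptions rather than taken from NNlib: convert the `-Inf` literal to the element type before it reaches `ifelse`, so both branches have the same type. The names `select_stable` and `f_stable` are hypothetical, and whether Enzyme then differentiates this variant is not confirmed anywhere in this thread; rerunning the `Enzyme.autodiff` call from the MWE with `f_stable` would be the way to check.

```julia
# Hypothetical type-stable variant of the `ifelse` from the MWE.
# `oftype(v, -Inf)` converts the Float64 literal to the type of `v`.
select_stable(v) = ifelse(isequal(v, Inf), oftype(v, -Inf), v)

# Same structure as `f` in the MWE, using the type-stable selection.
function f_stable(out, x)
    out[1] = select_stable(x[1])
    return out[1]
end

x = ones(Float32, 1)
@show f_stable(copy(x), copy(x))

# Inference should now report a single concrete type instead of a union.
@show only(Base.return_types(select_stable, (Float32,)))
```

Because both branches now share one type, there is no union left for the compiler to split before the store.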

toollu commented 1 year ago

I'm (on 0.11.2) also hitting

ERROR: Enzyme compilation failed due to illegal type analysis.
Current scope: 
; Function Attrs: mustprogress willreturn
define internal fastcc i64 @preprocess_julia_partition__9106({} addrspace(10)* noundef nonnull readonly align 16 dereferenceable(40) %0, i64 signext %1, i64 signext %2, i64 signext %3, { {} addrspace(10)*, i8, i8 } addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(16) %4, i8 zeroext %5, { {} addrspace(10)*, i8, i8 } addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(16) %6, i64 signext %7) unnamed_addr #80 !dbg !5321 {
top:

However

Caused by:
Stacktrace:
 [1] setindex!
   @ ./array.jl:969
 [2] partition!
   @ ./sort.jl:1004

Is this related or should I try to reduce to a MWE?

wsmoses commented 1 year ago

That is unrelated; a minimal example would be helpful, @toollu.

wsmoses commented 6 months ago

Closing, as an `ifelse` on a union of different types is considered unsupported atm.