EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Active variables passed by value to jl_threadsfor #675

Closed · jgreener64 closed 1 year ago

jgreener64 commented 1 year ago

Another zero gradient in structs issue. I am on Julia 1.8.5 and Enzyme main (2ccf4b). This works:

using Enzyme

const n_threads = Threads.nthreads()
const n = 100

struct I
    x::Float64
end

function force(inter, c1, c2)
    return inter.x * abs(c2 - c1)
end

function f_single(inter, coord)
    out_sum_threads = zeros(n_threads)
    for thread_id in 1:n_threads
        for i in thread_id:n_threads:n
            for j in (i + 1):n
                v = force(inter, coord[i], coord[j])
                out_sum_threads[thread_id] += v
            end
        end
    end
    return sum(out_sum_threads)
end

inter = I(5.0)
coord = rand(n)

d_coord = zeros(n)
autodiff(Reverse, f_single, Active, Active(inter), Duplicated(coord, d_coord))[1][1]
I(1618.0622389056841)
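
As a quick sanity check: force is linear in inter.x, so the gradient with respect to inter.x should equal the sum of pairwise absolute differences of the coordinates, which reproduces the value above for this particular rand(n) draw:

sum(abs(coord[j] - coord[i]) for i in 1:(n - 1) for j in (i + 1):n)  # ≈ 1618.06 here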

However, when I change the function to use threading, the gradient is zero:

function f_multi(inter, coord)
    out_sum_threads = zeros(n_threads)
    Threads.@threads for thread_id in 1:n_threads
        for i in thread_id:n_threads:n
            for j in (i + 1):n
                v = force(inter, coord[i], coord[j])
                out_sum_threads[thread_id] += v
            end
        end
    end
    return sum(out_sum_threads)
end

d_coord = zeros(n)
autodiff(Reverse, f_multi, Active, Active(inter), Duplicated(coord, d_coord))[1][1]
 caching call:   %46 = call fastcc i64 @julia_steprange_last_4852(i64 signext %45) #19, !dbg !112
┌ Warning: active variables passed by value to jl_threadsfor are not yet supported
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/Z5kZC/src/utils.jl:35
I(0.0)
wsmoses commented 1 year ago

Yeah, this is precisely the "active variables passed by value to jl_threadsfor" case, per the warning.
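
Until that lands, one possible workaround (an untested sketch; f_multi_ref and d_inter_ref are names invented here for illustration, reusing I, force, n, and n_threads from above) is to route the struct through a Ref marked Duplicated, so the closure that Threads.@threads builds captures a pointer with a shadow rather than an active value passed by value:

function f_multi_ref(inter_ref, coord)
    out_sum_threads = zeros(n_threads)
    Threads.@threads for thread_id in 1:n_threads
        inter = inter_ref[]  # dereference inside the threaded closure
        for i in thread_id:n_threads:n
            for j in (i + 1):n
                out_sum_threads[thread_id] += force(inter, coord[i], coord[j])
            end
        end
    end
    return sum(out_sum_threads)
end

inter_ref = Ref(I(5.0))
d_inter_ref = Ref(I(0.0))  # zero-initialized shadow; the gradient should accumulate here
d_coord = zeros(n)
autodiff(Reverse, f_multi_ref, Active, Duplicated(inter_ref, d_inter_ref), Duplicated(coord, d_coord))
d_inter_ref[]  # expected to hold the gradient w.r.t. inter, if this path is supported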

jgreener64 commented 1 year ago

Ah okay, I think I tuned that warning out after a while.

wsmoses commented 1 year ago

I mean we can keep the issue open as a reminder to finish that!

jeremiedb commented 1 year ago

I just stumbled on this issue (I got valid gradients when using a vanilla loop, but zero gradients when looping with @threads). Is there anything I can do to help with this? I'm somewhat excited to be almost able to ditch custom manual backprop on trees :)

jeremiedb commented 1 year ago

@wsmoses There still appears to be an issue with threads in split mode.

The following first shows the simple non-threaded mutating kernel. Gradients work fine:

using Enzyme

function f3(x, y)
    for j in axes(x, 2)
        for i in axes(x, 1)
            y[i, j] *= x[i, j]^2
        end
    end
    return nothing
end

x = zeros(3,4) .+ collect(2:4);
∂x = zeros(3,4);
y = ones(3,4);
∂y = zeros(3,4);

forward, backward = autodiff_thunk(
    ReverseSplitNoPrimal,
    Const{typeof(f3)},
    Const,
    Duplicated{typeof(x)},
    Duplicated{typeof(y)},
);

tape, result, shadow_result = forward(Const(f3), Duplicated(x, ∂x), Duplicated(y, ∂y));
∂y = ones(3,4);
_ = backward(Const(f3), Duplicated(x, ∂x), Duplicated(y, ∂y), tape)

julia> ∂x
3×4 Matrix{Float64}:
 4.0  4.0  4.0  4.0
 6.0  6.0  6.0  6.0
 8.0  8.0  8.0  8.0

By adding @threads to the outer loop, the gradients for x become zero:

using Base.Threads: @threads

function f3t(x, y)
    @threads for j in axes(x, 2)
        for i in axes(x, 1)
            y[i, j] *= x[i, j]^2
        end
    end
    return nothing
end

x = zeros(3,4) .+ collect(2:4);
∂x = zeros(3,4);
y = ones(3,4);
∂y = zeros(3,4);

forward, backward = autodiff_thunk(
    ReverseSplitNoPrimal,
    Const{typeof(f3t)},
    Const,
    Duplicated{typeof(x)},
    Duplicated{typeof(y)},
);

tape, result, shadow_result = forward(Const(f3t), Duplicated(x, ∂x), Duplicated(y, ∂y));
∂y = ones(3,4);
_ = backward(Const(f3t), Duplicated(x, ∂x), Duplicated(y, ∂y), tape)

julia> ∂x
3×4 Matrix{Float64}:
 0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0

By moving the ∂y = ones(3,4) assignment to before the forward call, the gradients come out fine. However, in actual usage of split mode, the proper values for ∂y aren't known in advance, so they can't be set beforehand. Also, f3 and f3t would be expected to behave the same.

wsmoses commented 1 year ago

You are required to pass the same arrays for both the forward and reverse functions.

Instead of doing ∂y = ones(3,4);, can you overwrite it in place and do ∂y .= ones(3,4);?
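
Concretely, the threaded split-mode sequence from above becomes (a minimal sketch reusing f3t, forward, and backward from the earlier snippet; the only change is seeding ∂y in place so it stays the same array object across both calls):

x = zeros(3,4) .+ collect(2:4);
∂x = zeros(3,4);
y = ones(3,4);
∂y = zeros(3,4);

tape, result, shadow_result = forward(Const(f3t), Duplicated(x, ∂x), Duplicated(y, ∂y));
∂y .= ones(3,4);  # overwrite in place: same array the forward thunk saw
_ = backward(Const(f3t), Duplicated(x, ∂x), Duplicated(y, ∂y), tape)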

wsmoses commented 1 year ago

@gaurav-arya since you were looking for things to do (and potentially necessary for custom rules that themselves call Enzyme), any appetite for documenting split mode?

jeremiedb commented 1 year ago

You are required to pass the same arrays for both the forward and reverse functions.

It works! So the gradients in the non-threaded case were just coincidentally correct? That got me distracted into thinking threads were the root cause, but the requirement for the same arrays on forward and reverse makes much sense. Thanks for the quick reply!