jgreener64 closed this 1 year ago
Yeah, this is precisely the case of active variables passed by value to jl_threadsfor, per the warning.
Ah okay, I think I tuned that warning out after a while.
I mean we can keep the issue open as a reminder to finish that!
I just stumbled on this issue (I got valid gradients when using a vanilla loop, but zero gradients when looping with @threads). Is there anything I can do to help with this? I'm quite excited to be almost able to ditch custom manual backprop on trees :)
@wsmoses There still appears to be an issue when using threads with split mode.
The following first shows the simple non-threaded mutating kernel, for which the gradients work fine:
using Enzyme

function f3(x, y)
for j in axes(x, 2)
for i = axes(x, 1)
y[i, j] *= x[i, j]^2
end
end
return nothing
end
x = zeros(3,4) .+ collect(2:4);
∂x = zeros(3,4);
y = ones(3,4);
∂y = zeros(3,4);
forward, backward = autodiff_thunk(
ReverseSplitNoPrimal,
Const{typeof(f3)},
Const,
Duplicated{typeof(x)},
Duplicated{typeof(y)},
);
tape, result, shadow_result = forward(Const(f3), Duplicated(x, ∂x), Duplicated(y, ∂y));
∂y = ones(3,4);
_ = backward(Const(f3), Duplicated(x, ∂x), Duplicated(y, ∂y), tape)
julia> ∂x
3×4 Matrix{Float64}:
4.0 4.0 4.0 4.0
6.0 6.0 6.0 6.0
8.0 8.0 8.0 8.0
By adding @threads to the outer loop, the gradients for x become zero:
using Base.Threads: @threads

function f3t(x, y)
@threads for j in axes(x, 2)
for i = axes(x, 1)
y[i, j] *= x[i, j]^2
end
end
return nothing
end
x = zeros(3,4) .+ collect(2:4);
∂x = zeros(3,4);
y = ones(3,4);
∂y = zeros(3,4);
forward, backward = autodiff_thunk(
ReverseSplitNoPrimal,
Const{typeof(f3t)},
Const,
Duplicated{typeof(x)},
Duplicated{typeof(y)},
);
tape, result, shadow_result = forward(Const(f3t), Duplicated(x, ∂x), Duplicated(y, ∂y));
∂y = ones(3,4);
_ = backward(Const(f3t), Duplicated(x, ∂x), Duplicated(y, ∂y), tape)
julia> ∂x
3×4 Matrix{Float64}:
0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0
If ∂y is set before the forward call, the gradients are fine. However, when using split mode, the proper values for ∂y aren't known in advance, so this isn't possible. I would also expect f3 and f3t to behave the same.
You are required to pass the same arrays to both the forward and reverse functions.
Instead of ∂y = ones(3,4), can you overwrite it in place with ∂y .= ones(3,4)?
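The difference between the two assignments is plain Julia semantics. Plain = rebinds the name to a brand-new array. Broadcast assignment .= writes into the array that already exists. Below is a minimal sketch with no Enzyme involved. The variable held is a hypothetical stand-in for whatever reference to the shadow array the forward pass keeps from the first call. It is an assumption that this is why rebinding breaks the reverse pass here.

```julia
# Rebinding vs. in-place mutation in plain Julia (no Enzyme needed).
# `held` is a hypothetical stand-in for a reference to the shadow array
# that is assumed to be retained from the forward call.

dy = zeros(3, 4)
held = dy              # `held` and `dy` name the same array object

dy = ones(3, 4)        # rebinding: `dy` now names a NEW array
println(held === dy)   # false: `held` still refers to the old array
println(sum(held))     # 0.0: the old array was never updated

dy = zeros(3, 4)
held = dy              # again, both names refer to one array

dy .= ones(3, 4)       # broadcast assignment writes into the existing array
println(held === dy)   # true: still the same object
println(sum(held))     # 12.0: the shared array now holds ones
```

With .=, any code that kept a reference to the original ∂y array sees the new seed values. With =, that code keeps seeing the old zeros.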
@gaurav-arya, since you were looking for things to do (and since it is potentially necessary for custom rules that themselves call Enzyme), any appetite for documenting split mode?
> You are required to pass the same arrays for both the forward and reverse functions.
It works! So the gradients in the non-threaded case were only correct by coincidence? That misled me into thinking threads were the root cause. But the requirement to use the same arrays in the forward and reverse passes makes a lot of sense. Thanks for the quick reply!
Here is another zero-gradient issue, this time with structs. I am on Julia 1.8.5 and Enzyme main (2ccf4b). This works:
However, when I change the function to use threading, the gradient is zero: