We should be caching the parameter gradients, the loss function, the compiled trace, and so on, but this should be a good initial version; we need a redesign of the training API later anyway.
Need to wait for SciMLSensitivity https://github.com/SciML/SciMLSensitivity.jl/pull/1046 before the doc build goes through
Structural in what way?
(Replying to Avik Pal's review comment on ext/LuxEnzymeExt.jl, https://github.com/LuxDL/Lux.jl/pull/640#discussion_r1597737388:)
```julia
using Setfield: @set!

function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data,
        ts::Lux.Experimental.TrainState) where {F}
    dps = Enzyme.make_zero(ts.parameters)
    fwd, rev = Enzyme.autodiff_thunk(
        Enzyme.ReverseSplitWithPrimal, Enzyme.Const{typeof(objective_function)},
        Enzyme.Active, Enzyme.Const{typeof(ts.model)},
        Enzyme.Duplicated{typeof(ts.parameters)},
        Enzyme.Const{typeof(ts.states)}, Enzyme.Const{typeof(data)})
    tape, (loss, st_new, stats), shadow_result = fwd(
        Enzyme.Const(objective_function), Enzyme.Const(ts.model),
        Enzyme.Duplicated(ts.parameters, dps), Enzyme.Const(ts.states), Enzyme.Const(data))
    rev(Enzyme.Const(objective_function), Enzyme.Const(ts.model),
        Enzyme.Duplicated(ts.parameters, dps), Enzyme.Const(ts.states), Enzyme.Const(data),
        (one(loss), Enzyme.make_zero(st_new), Enzyme.make_zero(stats)), tape)
    # (remainder of the reviewed hunk not quoted)
```
Is there a way to specify a structural zero instead of doing it like Enzyme.make_zero(st_new)?
> Structural in what way?

As in, I want to say "don't backpropagate wrt this value". For Zygote I would put a `nothing`.
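For concreteness, here is a minimal sketch (not from the thread, names are illustrative only) of the difference being discussed: Zygote expresses a structural zero by simply substituting `nothing`, whereas Enzyme wants a materialized zero shadow with the same structure, e.g. from `Enzyme.make_zero`.

```julia
using Enzyme

# A toy "state" like the st_new being discussed.
st = (; running_mean = zeros(Float32, 4), running_var = ones(Float32, 4))

# Enzyme: a recursively zeroed copy with the same structure, passed as the
# cotangent/shadow for the non-loss outputs.
dst_enzyme = Enzyme.make_zero(st)

# Zygote-style structural zero: `nothing` in place of the whole cotangent,
# meaning "do not backpropagate wrt this value".
dst_zygote = nothing
```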
No, activity annotations (e.g. to differentiate or not to differentiate) are presently only at the argument or return level.
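In other words (an illustrative sketch, not from the thread), each argument gets a single activity annotation and the return activity is chosen as a whole, so there is no way to mark only part of a returned tuple as non-differentiable:

```julia
using Enzyme

f(w, x) = sum(abs2, w .* x)

w  = rand(Float32, 4)
dw = Enzyme.make_zero(w)   # shadow accumulator for the gradient wrt w
x  = rand(Float32, 4)

# Activity is specified per argument (and once for the return value):
Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active,
    Enzyme.Duplicated(w, dw),   # differentiate wrt w; gradient accumulates into dw
    Enzyme.Const(x))            # x is held constant (not differentiated)
```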
So how would I annotate the return type? I am getting a tuple containing a scalar, a named tuple, and an arbitrary object; we don't need to backpropagate through the last two.
Honestly, I would just pass in a wrapper function which first calls the original function.
You mean something like:

```julia
function compute_gradients(........)
    st_new_outer = Ref()
    stats_outer = Ref()
    function wrapper_function(args...)
        y, st_new, stats = objective_function(args...)
        st_new_outer[] = st_new
        stats_outer[] = stats
        return y
    end
    .....
end
```
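Roughly what the elided part might look like (a sketch only, under the assumptions that `ts`, `data`, and `objective_function` are as in the diff quoted above and that only the scalar loss should be differentiated); the named tuple and stats captured in the `Ref`s never enter the reverse pass:

```julia
dps = Enzyme.make_zero(ts.parameters)
_, loss = Enzyme.autodiff(
    Enzyme.ReverseWithPrimal, Enzyme.Const(wrapper_function), Enzyme.Active,
    Enzyme.Const(ts.model), Enzyme.Duplicated(ts.parameters, dps),
    Enzyme.Const(ts.states), Enzyme.Const(data))
# Only the scalar return of wrapper_function is differentiated; the updated
# states and stats are read back out of st_new_outer[] and stats_outer[].
```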
Yeah
```julia
using ADTypes, Lux, Random, Enzyme, Optimisers
using BenchmarkTools  # for @btime

model = Chain(Conv((3, 3), 3 => 6), GroupNorm(6, 3, gelu), Conv((3, 3), 6 => 32),
    BatchNorm(32, gelu), GlobalMeanPool(), FlattenLayer(), Dense(32, 1))

x = rand(Float32, 32, 32, 3, 4);
tstate = Lux.Experimental.TrainState(Xoshiro(0), model, Adam(0.001f0));

function obj_fn(model, ps, st, x)
    y, st_new = model(x, ps, st)
    return sum(abs2, y), st_new, (;)
end

grads, loss, stats, tstate_new = Lux.Experimental.compute_gradients(
    AutoEnzyme(), obj_fn, x, tstate);
grads, loss, stats, tstate_new = Lux.Experimental.compute_gradients(
    AutoEnzyme(), obj_fn, x, tstate_new);

@btime Lux.Experimental.compute_gradients($AutoEnzyme(), $obj_fn, $x, $tstate);
# 14.726 ms (461 allocations: 9.75 MiB)
@btime Lux.Experimental.compute_gradients($AutoEnzyme(), $obj_fn, $x, $tstate_new);
# 14.233 ms (447 allocations: 9.74 MiB)
```
Caching seems to work correctly.
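For completeness, this is roughly how the call would sit inside a training loop (a sketch assuming the `Lux.Experimental.apply_gradients(tstate, grads)` pairing for the optimiser step; `train_loop` and `data_iter` are hypothetical names, with `data_iter` yielding batches shaped like `x` above):

```julia
function train_loop(tstate, data_iter)
    for x_batch in data_iter
        # Compute gradients of obj_fn wrt the parameters in tstate.
        grads, loss, stats, tstate = Lux.Experimental.compute_gradients(
            AutoEnzyme(), obj_fn, x_batch, tstate)
        # Take an optimiser step and get the updated train state back.
        tstate = Lux.Experimental.apply_gradients(tstate, grads)
    end
    return tstate
end
```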
Ok, I did something wrong; it segfaulted the training test: https://github.com/LuxDL/Lux.jl/actions/runs/9056705562/job/24879628489?pr=640#step:6:739
Locally things pass. Now we need to wait for the SciMLSensitivity compat bounds to be updated.
CI is not picking up on the latest SciMLSensitivity
Attention: Patch coverage is 74.19355% with 16 lines in your changes missing coverage. Please review.

Project coverage is 87.16%. Comparing base (64ba96d) to head (8bdde08).

| Files | Patch % | Lines |
| --- | --- | --- |
| src/utils.jl | 26.66% | 11 Missing :warning: |
| ext/LuxEnzymeExt.jl | 86.36% | 3 Missing :warning: |
| src/contrib/training.jl | 89.47% | 2 Missing :warning: |