LuxDL / Lux.jl

Elegant & Performant Scientific Machine Learning in Julia
https://lux.csail.mit.edu/
MIT License
456 stars, 54 forks

Enzyme Testing + Caching in `compute_gradients` #640

Closed: avik-pal closed this pull request 2 months ago.

avik-pal commented 2 months ago

We should be caching the parameter gradients, the compiled trace of the loss function, and so on, but this should be a good initial version; we need a redesign of the training API later on anyway.
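The caching idea in this comment can be sketched without any AD package. This is a minimal illustration, assuming the cache lives inside the training state returned by each call and is keyed on the objective and argument types. `CompiledThunk`, `ToyTrainState`, `compile_thunk`, and `toy_compute_with_cache` are hypothetical names, not Lux or Enzyme API, and the "compiled" object here merely wraps the objective in place of an expensive compilation step.

```julia
# Hypothetical stand-in for an expensive compiled artifact (e.g. a gradient thunk)
struct CompiledThunk{F}
    f::F
end

# Hypothetical training state carrying an optional cache entry: `nothing` or (key, thunk)
struct ToyTrainState{P,C}
    parameters::P
    cache::C
end

ToyTrainState(ps) = ToyTrainState(ps, nothing)

# Counts how many times the (pretend) compilation step runs
const COMPILE_COUNT = Ref(0)

function compile_thunk(objective, ps, data)
    COMPILE_COUNT[] += 1          # placeholder for the expensive compile step
    return CompiledThunk(objective)
end

# Reuse the cached thunk when the objective and argument types are unchanged;
# otherwise compile a new one. The cache is stored in the returned state.
function toy_compute_with_cache(objective::F, data, ts::ToyTrainState) where {F}
    key = (F, typeof(ts.parameters), typeof(data))
    thunk = if ts.cache isa Tuple && first(ts.cache) == key
        last(ts.cache)
    else
        compile_thunk(objective, ts.parameters, data)
    end
    loss = thunk.f(ts.parameters, data)
    return loss, ToyTrainState(ts.parameters, (key, thunk))
end
```

Passing the returned state back in skips the compile step as long as the types match; switching the data from `Float64` to `Float32` would trigger a fresh compile in this sketch.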

avik-pal commented 2 months ago

We need to wait for SciMLSensitivity (https://github.com/SciML/SciMLSensitivity.jl/pull/1046) before the doc build can go through.

wsmoses commented 2 months ago

Structural in what way?

On Sun, May 12, 2024 at 4:19 PM Avik Pal wrote:

In ext/LuxEnzymeExt.jl (https://github.com/LuxDL/Lux.jl/pull/640#discussion_r1597737388):

using Setfield: @set!

function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data,
        ts::Lux.Experimental.TrainState) where {F}
    dps = Enzyme.make_zero(ts.parameters)
    fwd, rev = Enzyme.autodiff_thunk(
        Enzyme.ReverseSplitWithPrimal, Enzyme.Const{typeof(objective_function)},
        Enzyme.Active, Enzyme.Const{typeof(ts.model)},
        Enzyme.Duplicated{typeof(ts.parameters)},
        Enzyme.Const{typeof(ts.states)}, Enzyme.Const{typeof(data)})
    tape, (loss, st_new, stats), shadow_result = fwd(
        Enzyme.Const(objective_function), Enzyme.Const(ts.model),
        Enzyme.Duplicated(ts.parameters, dps), Enzyme.Const(ts.states), Enzyme.Const(data))
    rev(Enzyme.Const(objective_function), Enzyme.Const(ts.model),
        Enzyme.Duplicated(ts.parameters, dps), Enzyme.Const(ts.states), Enzyme.Const(data),
        (one(loss), Enzyme.make_zero(st_new), Enzyme.make_zero(stats)), tape)

Is there a way to specify a structural zero instead of using Enzyme.make_zero(st_new)?
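The split forward/reverse calls in the quoted code can be illustrated without Enzyme. Below is a hand-written toy, not Enzyme's API: `toy_forward` and `toy_reverse!` are hypothetical names, the forward pass returns a tape plus the primal outputs `(loss, st_new, stats)`, and the reverse pass accumulates into the shadow `dps` using only a seed for the loss. In the quoted Enzyme call, by contrast, the seed passed to `rev` is a tuple mirroring the whole return value, which is why `make_zero` appears there for the state and stats.

```julia
# Toy objective: loss = sum((ps .* x).^2), plus a new "state" and some "stats".
# Forward pass: compute the primal outputs and record what the reverse pass needs.
function toy_forward(ps::Vector{Float64}, x::Vector{Float64})
    z = ps .* x
    loss = sum(abs2, z)
    st_new = (; calls = 1)
    stats = (; nparams = length(ps))
    tape = (z, x)                  # intermediates saved for the reverse pass
    return tape, (loss, st_new, stats)
end

# Reverse pass: add dloss * d(loss)/d(ps) into the shadow `dps`.
# Only the loss is seeded; st_new and stats are treated as non-differentiable.
function toy_reverse!(dps::Vector{Float64}, dloss::Float64, tape)
    z, x = tape
    dps .+= dloss .* 2 .* z .* x   # d/dps sum((ps .* x).^2) = 2 .* (ps .* x) .* x
    return dps
end
```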


avik-pal commented 2 months ago

> Structural in what way?

As in, I want to say: don't backpropagate with respect to this value. For Zygote, I would pass nothing for it.
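For comparison, a minimal sketch of the Zygote behaviour described here, assuming Zygote is installed. `Zygote.pullback` returns the primal output together with a pullback function, and passing `nothing` as the cotangent for the state and stats entries means no gradient flows back from them. `toy_obj` is a hypothetical objective with the same `(loss, state, stats)` output shape as in this thread.

```julia
using Zygote

# Hypothetical objective with the (loss, new_state, stats) output shape
toy_obj(ps) = (sum(abs2, ps), (; step = 1), (; note = "stats"))

(loss, st_new, stats), back = Zygote.pullback(toy_obj, [1.0, 2.0])

# Seed the loss with 1 and pass `nothing` for the outputs we do not differentiate
(dps,) = back((one(loss), nothing, nothing))
```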

wsmoses commented 2 months ago

No. Activity annotations (e.g., whether or not to differentiate) are presently specified only at the argument or return level.
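To illustrate argument-level activity, here is a small sketch assuming Enzyme is installed; `g` is a hypothetical function. Each argument is wrapped as `Active` (differentiate) or `Const` (do not), and the return annotation covers the whole return value, so there is no way to mark only part of a returned tuple as inactive.

```julia
using Enzyme

# Hypothetical scalar function of two arguments
g(x, y) = x^2 * y

# Activity is declared per argument: x is Active, y is Const.
# The `Active` return annotation applies to the entire return value.
result = Enzyme.autodiff(Enzyme.Reverse, g, Enzyme.Active, Enzyme.Active(3.0), Enzyme.Const(2.0))

# result[1] holds one entry per argument; Const arguments get `nothing`
dx, dy = result[1]
```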


avik-pal commented 2 months ago

So how would I annotate the return type? I am getting a tuple containing a scalar, a named tuple, and an arbitrary object; we don't need to backpropagate through the last two.

wsmoses commented 2 months ago

Honestly, I would just pass in a wrapper function that calls the original function first and returns only the scalar loss.


avik-pal commented 2 months ago

You mean something like:

function compute_gradients(........)
    st_new_outer = Ref{Any}()
    stats_outer = Ref{Any}()

    function wrapper_function(args...)
        y, st_new, stats = objective_function(args...)
        st_new_outer[] = st_new
        stats_outer[] = stats
        return y
    end

    .....
end

wsmoses commented 2 months ago

Yeah
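A dependency-free sketch of the wrapper approach agreed on here. To keep it runnable without an AD package, a hypothetical central finite-difference helper `fd_gradient` stands in for the AD backend; `objective` and `compute_gradients_wrapped` are likewise illustrative names, not Lux API. The point is that the differentiated function returns only the scalar loss, while the new state and stats are stashed in `Ref`s and read back afterwards.

```julia
# Hypothetical stand-in for an AD backend: central finite differences
function fd_gradient(f, ps::Vector{Float64}; h = 1e-6)
    g = similar(ps)
    for i in eachindex(ps)
        e = zeros(length(ps))
        e[i] = h
        g[i] = (f(ps .+ e) - f(ps .- e)) / (2h)
    end
    return g
end

# Hypothetical objective with the (loss, new_state, stats) output shape
function objective(ps, st, x)
    loss = sum(abs2, ps .* x)
    return loss, (; count = st.count + 1), (; batch = length(x))
end

# Wrapper pattern: differentiate a scalar-only closure, capture the rest in Refs
function compute_gradients_wrapped(objective, ps, st, x)
    st_new_outer = Ref{Any}()
    stats_outer = Ref{Any}()
    function wrapper_function(p)
        y, st_new, stats = objective(p, st, x)
        st_new_outer[] = st_new
        stats_outer[] = stats
        return y
    end
    grads = fd_gradient(wrapper_function, ps)
    # Evaluate once more at ps so the loss and the Refs reflect the unperturbed inputs
    loss = wrapper_function(ps)
    return grads, loss, st_new_outer[], stats_outer[]
end
```

With a real reverse-mode backend, the forward pass at `ps` would fill the `Ref`s itself, so the extra evaluation in this sketch would presumably not be needed.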


avik-pal commented 2 months ago

using ADTypes, Lux, Random, Enzyme, Optimisers, BenchmarkTools

model = Chain(Conv((3, 3), 3 => 6), GroupNorm(6, 3, gelu), Conv((3, 3), 6 => 32),
    BatchNorm(32, gelu), GlobalMeanPool(), FlattenLayer(), Dense(32, 1))

x = rand(Float32, 32, 32, 3, 4);
tstate = Lux.Experimental.TrainState(Xoshiro(0), model, Adam(0.001f0));

function obj_fn(model, ps, st, x)
    y, st_new = model(x, ps, st)
    return sum(abs2, y), st_new, (;)
end

grads, loss, stats, tstate_new = Lux.Experimental.compute_gradients(
    AutoEnzyme(), obj_fn, x, tstate);

grads, loss, stats, tstate_new = Lux.Experimental.compute_gradients(
    AutoEnzyme(), obj_fn, x, tstate_new);

@btime Lux.Experimental.compute_gradients($AutoEnzyme(), $obj_fn, $x, $tstate);
# 14.726 ms (461 allocations: 9.75 MiB)

@btime Lux.Experimental.compute_gradients($AutoEnzyme(), $obj_fn, $x, $tstate_new);
# 14.233 ms (447 allocations: 9.74 MiB)

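Caching seems to work correctly: the second benchmark, which passes the tstate_new returned by an earlier call, makes slightly fewer allocations (447 vs. 461) and runs slightly faster.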

avik-pal commented 2 months ago

OK, I did something wrong: it segfaulted the training test (https://github.com/LuxDL/Lux.jl/actions/runs/9056705562/job/24879628489?pr=640#step:6:739).

avik-pal commented 2 months ago

Things pass locally. Now we need to wait for the SciMLSensitivity compat entries to be updated.

avik-pal commented 2 months ago

CI is not picking up the latest SciMLSensitivity version.

codecov[bot] commented 2 months ago

Codecov Report

Attention: Patch coverage is 74.19355%, with 16 lines in your changes missing coverage. Please review.

Project coverage is 87.16%. Comparing base (64ba96d) to head (8bdde08).

| Files | Patch % | Lines |
|---|---|---|
| src/utils.jl | 26.66% | 11 Missing |
| ext/LuxEnzymeExt.jl | 86.36% | 3 Missing |
| src/contrib/training.jl | 89.47% | 2 Missing |

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #640      +/-   ##
==========================================
- Coverage   87.56%   87.16%   -0.40%
==========================================
  Files          49       50       +1
  Lines        2380     2439      +59
==========================================
+ Hits         2084     2126      +42
- Misses        296      313      +17
```
