FluxML / Optimisers.jl

Optimisers.jl defines many standard optimisers and utilities for learning loops.
https://fluxml.ai/Optimisers.jl
MIT License

add `trainables` #171

Closed CarloLucibello closed 6 months ago

CarloLucibello commented 6 months ago

An alternative to #57, adding a `trainables` method that returns a vector of arrays.
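
For context, a minimal usage sketch (assuming the method ends up exported as `trainables`; the model here is a hypothetical example, not from the PR):

using Flux, Optimisers

m = Chain(Dense(2 => 3, relu), Dense(3 => 1))
ps = trainables(m)   # Vector of the model's trainable arrays
sum(length, ps)      # total parameter count: 3*2 + 3 + 1*3 + 1 = 13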

I'm playing with different implementations at the moment. The output of the `perf()` function below is:

trainables1
  1.717 μs (33 allocations: 1.48 KiB)
trainables2
  2.708 μs (63 allocations: 3.12 KiB)
trainables3
  11.208 μs (147 allocations: 4.39 KiB)

gradient trainables2
  1.546 ms (7157 allocations: 304.39 KiB)
gradient trainables3
  249.625 μs (2289 allocations: 115.39 KiB)

trainables1 is the fastest, but since it mutates an accumulator it needs a custom rrule for differentiation. The rrule for `destructure` can probably be adapted to this case: https://github.com/FluxML/Optimisers.jl/blob/master/src/destructure.jl

The gradients of the other two implementations are very slow; in those cases we would also need a custom rule.
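
The general shape of such a rule, sketched with hypothetical names (`P`, `collect_params`, and `∇collect_params` are stand-ins; the concrete rule for trainables1 appears in the script below): the forward pass collects the arrays, and the backward pass routes the vector of cotangents into a structural tangent mirroring the input.

using ChainRulesCore

struct P; w; b; end

collect_params(p::P) = [p.w, p.b]                          # stand-in collector
∇collect_params(p::P, Δ) = Tangent{P}(w = Δ[1], b = Δ[2])  # scatter cotangents back

function ChainRulesCore.rrule(::typeof(collect_params), p::P)
    y = collect_params(p)
    back(Δ) = (NoTangent(), ∇collect_params(p, unthunk(Δ)))
    return y, back
end

The rule for `destructure` follows the same pattern, with extra bookkeeping for offsets into a single flat vector.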

TODO

CarloLucibello commented 6 months ago

Implemented the custom rule for trainables1. With the last version of `perf()`, the measurements are now:

trainables1
  9.833 μs (54 allocations: 2.34 KiB)
trainables2
  11.625 μs (94 allocations: 4.30 KiB)
trainables3
  22.625 μs (189 allocations: 5.70 KiB)

gradient trainables1
  29.584 μs (213 allocations: 268.50 KiB)
gradient trainables2
  1.825 ms (8419 allocations: 601.59 KiB)
gradient trainables3
  307.000 μs (2636 allocations: 377.53 KiB)
CarloLucibello commented 6 months ago

I'll focus on trainables1. Pasting the full script here for future reference:

using BenchmarkTools
using Optimisers
using Functors
using Zygote, Flux
using ChainRulesCore

function trainables1(x)
    arrays = AbstractArray[]
    exclude(x) = Optimisers.isnumeric(x)
    # walk the trainable branches only, pushing every numeric leaf
    fmap(x; exclude, walk = Optimisers._TrainableStructWalk()) do y
        push!(arrays, y)
        return y
    end
    return arrays
end

function ∇trainables1(x, Δ)
    exclude(x) = Optimisers.isnumeric(x)
    i = 0
    # rebuild a structural tangent, visiting leaves in the same order as trainables1
    return fmapstructure(x; exclude, walk = Optimisers._TrainableStructWalk()) do _
        return Δ[i += 1]
    end
end

function ChainRulesCore.rrule(::typeof(trainables1), x)
    y = trainables1(x)
    trainables_back(Δ) = (NoTangent(), ∇trainables1(x, unthunk(Δ)))
    return y, trainables_back
end
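
# Sanity-check sketch (not part of the original script): with the rrule above,
# the gradient of a sum-of-squares loss should be a structural tangent holding
# 2 .* the value of each trainable leaf.
function check_rrule()
    m = Dense(2 => 2)
    g = Zygote.gradient(m -> sum([sum(abs2, p) for p in trainables1(m)]), m)[1]
    @assert g.weight ≈ 2 .* m.weight
    @assert g.bias ≈ 2 .* m.bias
end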

############

using Functors: AbstractWalk, execute, ExcludeWalk

struct TrainableWalk2 <: AbstractWalk end

function (walk::TrainableWalk2)(recurse, x, ys...)
    x_children = Optimisers.trainable(x)
    ys_children = map(Optimisers.trainable, ys)
    res = map(recurse, x_children, ys_children...)
    # flatten the children's results with an iterative reduce
    return reduce(vcat, values(res); init = [])
end

function trainables2(x)
    exclude(x) = Optimisers.isnumeric(x) && Functors.isleaf(x)
    return execute(ExcludeWalk(TrainableWalk2(), x -> [x], exclude), x)
end

struct TrainableWalk3 <: AbstractWalk end

function (walk::TrainableWalk3)(recurse, x, ys...)
    x_children = Optimisers.trainable(x)
    ys_children = map(Optimisers.trainable, ys)
    res = map(recurse, x_children, ys_children...)
    # same as TrainableWalk2, but splats into vcat instead of reducing
    return vcat(values(res)...)
end

function trainables3(x)
    exclude(x) = Optimisers.isnumeric(x)
    return execute(ExcludeWalk(TrainableWalk3(), x -> [x], exclude), x)
end

function floss(ps)
    # toy loss: sum of squares over every parameter array
    sum([sum(abs2, p) for p in ps])
end

function perf()
    # toy model mixing Dense layers, BatchNorm state, and a non-trainable closure
    m = Chain(Dense(128 => 128, relu),
              Dense(128 => 128, relu),
              BatchNorm(128),
              x -> x^2,
              Dense(128 => 128, relu),
              Dense(128 => 128, relu))

    println("trainables1")
    @btime floss(trainables1($m))
    println("trainables2")
    @btime floss(trainables2($m))
    println("trainables3")
    @btime floss(trainables3($m))
    println()

    println("gradient trainables1")
    @btime gradient(m -> floss(trainables1(m)), $m)
    println("gradient trainables2")
    @btime gradient(m -> floss(trainables2(m)), $m)
    println("gradient trainables3")
    @btime gradient(m -> floss(trainables3(m)), $m)

    nothing
end

Zygote.refresh()
perf()
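
As a cross-check of any of these implementations (a sketch, not part of the PR): `destructure` also collects only the trainable parameters, so the element counts should agree.

m = Chain(Dense(3 => 4, relu), BatchNorm(4))
v, _ = Optimisers.destructure(m)
@assert sum(length, trainables1(m)) == length(v)  # both skip BatchNorm's non-trainable stats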
CarloLucibello commented 6 months ago

Could anyone review?

darsnack commented 6 months ago

One API consideration before merging and releasing is whether this needs to be separate from https://github.com/FluxML/Optimisers.jl/pull/173. It's trivial to ignore the path if it's not relevant.

Also, while it's okay for Functors, I'm not a fan of duplicated `*_with_path` functions. As @ToucheSir suggested, if we want two variants, then a keyword flag seems like a better API.
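
A sketch of that keyword-flag API (hypothetical signature; `trainables_with_path` stands in for the #173 variant and is not an existing function):

function trainables(x; path::Bool = false)
    # path = true would return (path, array) pairs as in #173;
    # path = false returns just the arrays, as in this PR
    return path ? trainables_with_path(x) : trainables1(x)
end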

CarloLucibello commented 6 months ago

I kept #173 separate to simplify the review of this one. Let's continue the discussion there.