Closed: CarloLucibello closed this 6 months ago
Implemented a custom rule for trainables1. With the latest version of perf(), the measurements are now:
trainables1
9.833 μs (54 allocations: 2.34 KiB)
trainables2
11.625 μs (94 allocations: 4.30 KiB)
trainables3
22.625 μs (189 allocations: 5.70 KiB)
gradient trainables1
29.584 μs (213 allocations: 268.50 KiB)
gradient trainables2
1.825 ms (8419 allocations: 601.59 KiB)
gradient trainables3
307.000 μs (2636 allocations: 377.53 KiB)
I'll focus on trainables1, and I'm pasting the full script here for future reference:
using BenchmarkTools
using Optimisers
using Functors
using Zygote, Flux
using ChainRulesCore

# Implementation 1: collect the trainable arrays by pushing into a buffer
# while walking the model (mutating, hence the custom rrule below).
function trainables1(x)
    arrays = AbstractArray[]
    exclude(x) = Optimisers.isnumeric(x)
    fmap(x; exclude, walk = Optimisers._TrainableStructWalk()) do y
        push!(arrays, y)
        return y
    end
    return arrays
end

# Pullback for trainables1: walk the model again and place Δ[i] at the i-th
# trainable leaf, producing a structural tangent.
function ∇trainables1(x, Δ)
    exclude(x) = Optimisers.isnumeric(x)
    i = 0
    return fmapstructure(x; exclude, walk = Optimisers._TrainableStructWalk()) do _
        return Δ[i+=1]
    end
end

# Custom rrule so that the mutating trainables1 is differentiable.
function ChainRulesCore.rrule(::typeof(trainables1), x)
    y = trainables1(x)
    trainables_back(Δ) = (NoTangent(), ∇trainables1(x, unthunk(Δ)))
    return y, trainables_back
end

############

using Functors: AbstractWalk, _map, _values, execute, ExcludeWalk

# Implementation 2: a non-mutating custom walk that concatenates the recursed
# results with reduce(vcat, ...; init=[]).
struct TrainableWalk2 <: AbstractWalk end

function (walk::TrainableWalk2)(recurse, x, ys...)
    x_children = Optimisers.trainable(x)
    ys_children = map(Optimisers.trainable, ys)
    res = map(recurse, x_children, ys_children...)
    return reduce(vcat, values(res); init=[])
end

function trainables2(x)
    exclude(x) = Optimisers.isnumeric(x) && Functors.isleaf(x)
    return execute(ExcludeWalk(TrainableWalk2(), x -> [x], exclude), x)
end

# Implementation 3: like implementation 2, but splatting into vcat.
struct TrainableWalk3 <: AbstractWalk end

function (walk::TrainableWalk3)(recurse, x, ys...)
    x_children = Optimisers.trainable(x)
    ys_children = map(Optimisers.trainable, ys)
    res = map(recurse, x_children, ys_children...)
    return vcat(values(res)...)
end

function trainables3(x)
    exclude(x) = Optimisers.isnumeric(x)
    return execute(ExcludeWalk(TrainableWalk3(), x -> [x], exclude), x)
end

# Toy loss over the collected parameter arrays.
function floss(ps)
    sum([sum(abs2, p) for p in ps])
end

using Flux

# Benchmark the three implementations and their gradients.
function perf()
    m = Chain(Dense(128 => 128, relu),
              Dense(128 => 128, relu),
              BatchNorm(128),
              x -> x^2,
              Dense(128 => 128, relu),
              Dense(128 => 128, relu))

    println("trainables1")
    @btime floss(trainables1($m))
    println("trainables2")
    @btime floss(trainables2($m))
    println("trainables3")
    @btime floss(trainables3($m))
    println()

    println("gradient trainables1")
    @btime gradient(m -> floss(trainables1(m)), $m)
    println("gradient trainables2")
    @btime gradient(m -> floss(trainables2(m)), $m)
    println("gradient trainables3")
    @btime gradient(m -> floss(trainables3(m)), $m)
    nothing
end

Zygote.refresh()
perf()
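As a quick sanity check of the custom rrule (not part of the benchmark), the gradient taken through trainables1 can be compared against the one taken through destructure. A sketch, assuming the script above has been run:

# Compare the gradient through trainables1 with a destructure-based reference.
m2 = Chain(Dense(4 => 4, relu), Dense(4 => 2))

g1 = gradient(m -> floss(trainables1(m)), m2)[1]

g2 = gradient(m2) do m
    v, _ = Flux.destructure(m)   # destructure has its own rrule in Optimisers
    sum(abs2, v)
end[1]

# Both gradients are structural (nested NamedTuples); their numeric leaves should match.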
Could anyone review?
One API consideration before merging and releasing is whether this needs to be separate from https://github.com/FluxML/Optimisers.jl/pull/173. It's trivial to ignore the path if it's not relevant.
Also, while okay for Functors, I'm not a fan of duplicated *_with_path functions. Like @ToucheSir suggested, if we want two variants, then a keyword flag seems like a better API.
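To make the keyword-flag idea concrete, it could look roughly like the sketch below. Everything here is hypothetical (the trainables_sketch name, the with_path keyword, and the placeholder paths); it only reuses trainables1 from the script above, and a real implementation would return the actual key path of each leaf collected during the walk.

# Hypothetical keyword-flag API (sketch only, not the actual Optimisers interface).
function trainables_sketch(model; with_path::Bool = false)
    ps = trainables1(model)                 # prototype from the script above
    with_path || return ps
    # Placeholder: pair each array with its index; a real version would return
    # the leaf's key path from the walk instead.
    return [(i, p) for (i, p) in enumerate(ps)]
end

model = Chain(Dense(2 => 3, relu), Dense(3 => 1))
trainables_sketch(model)                    # Vector of parameter arrays
trainables_sketch(model; with_path = true)  # Vector of (placeholder_path, array) pairs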
I kept #173 separate to simplify the review of this one. Let's continue the discussion there.
An alternative to #57, adding a trainables method that returns a vector of arrays. I'm playing with different implementations at the moment. The output of the perf() function shows that trainables1 is the fastest, but since it is mutating it needs a custom rrule for differentiation. Probably the rrule for destructure can be adapted to this case: https://github.com/FluxML/Optimisers.jl/blob/master/src/destructure.jl. The gradients of the other two implementations are very slow, and in those cases we would also need a custom rule.
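For context, the intended usage would be along these lines (a sketch based on the trainables1 prototype and floss from the script above; the final name and exact behaviour may differ):

# Rough usage sketch of the proposed API (names are not final).
model = Chain(Dense(2 => 3, relu), Dense(3 => 1))

ps = trainables1(model)        # Vector of the trainable parameter arrays (no copies)
nparams = sum(length, ps)      # e.g. total number of trainable parameters

# Differentiable thanks to the custom rrule, e.g. for a weight-decay style penalty:
g = gradient(m -> floss(trainables1(m)), model)[1]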
TODO (fmap)