FluxML / Optimisers.jl

Optimisers.jl defines many standard optimisers and utilities for learning loops.
https://fluxml.ai/Optimisers.jl
MIT License

Add in-place `destructure!` #165

Open mcabbott opened 10 months ago

mcabbott commented 10 months ago

This adds a variant of destructure with minimal changes such that it writes back into the original model, instead of creating a copy. This may close #146, cc @glatteis
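
A hedged illustration of the intended semantics (not code from the PR): after re!(w), the original model itself holds the new values, so destructuring it again returns w.

using Flux, Optimisers

model = Dense(3 => 2)
v, re! = destructure!(model)
re!(2 .* v)                        # overwrites model's parameters in place
destructure(model)[1] == 2 .* v    # true: the original model now holds the new values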

Marked draft as it seems surprisingly slow -- why?

julia> model = Chain(Dense(28^2 => 1024, relu), Dense(1024 => 10));

julia> params, re = destructure(model);  # old

julia> params, re! = destructure!(model);  # new

julia> @btime $re($params);  # This is the reconstruction cost
  min 229.334 μs, mean 374.473 μs (70 allocations, 3.11 MiB)

julia> @btime copy($params);  # ... and it's mostly allocation, same mean:
  min 219.417 μs, mean 367.168 μs (3 allocations, 3.11 MiB)

julia> @btime $re!($params);  # this avoids the allocations, but is quite slow.
  min 432.917 μs, mean 472.293 μs (58 allocations, 2.02 KiB)
linusheck commented 10 months ago

Thank you! Weird that it is so slow.

ToucheSir commented 10 months ago

Could the in-place version be tripping some aliasing heuristic and hitting a slow path? I guess a profile would be illuminating.
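
For anyone who wants to check, a minimal profiling sketch with the Profile stdlib (model, params and re! as defined above):

using Profile

re!(params)                    # warm up / compile first
Profile.clear()
@profile for _ in 1:1000       # collect enough samples to see where the time goes
  re!(params)
end
Profile.print(mincount = 10)   # look for copyto! and any aliasing / memcpy checks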

kishore-nori commented 4 months ago

Thank you for initiating and implementing this idea; I think it's a great idea and would be very useful. I was trying this out because I am interested in copying parameters in place into a model from a flat vector. From my comparisons, I suspect that one reason for the slowness of the in-place version is cache effects in copyto! (it handles two big blocks of memory simultaneously), which don't occur in the usual restructure; this can be observed with a relatively smaller model.

Additionally, the following version (just using the fact that copyto! works on linear memory, so no reshape is needed) is faster in my tests:

function _rebuild_alt!(x, off, flat::AbstractVector, len = length(flat); walk = _Trainable_biwalk(), kw...)
  len == length(flat) || throw(DimensionMismatch("Rebuild expected a vector of length $len, got $(length(flat))"))
  fmap(x, off; exclude = isnumeric, walk, kw...) do y, o
    vecy = vec(y)
    copyto!(y, _getat_alt(vecy, o, flat, view))  # copies straight into y's memory, no reshape needed
  end
  x
end

_getat_alt(y::AbstractVector, o::Int, flat::AbstractVector, get = getindex) =
  ProjectTo(y)(get(flat, o .+ (1:length(y))))

and it gets better relative to the usual restructure as the model gets smaller, which again points at cache issues for copyto!:

using Flux, Optimisers, Zygote, BenchmarkTools

N = 1024
model = Chain(Dense(28^2 => N, relu), Dense(N => 10));

params,re = destructure(model)
params!,re! = destructure!(model)
params_alt!,re_alt! = destructure_alt!(model) # using above alternatives

@btime $re($params)
@btime $re!($params)
@btime $re_alt!($params)

  106.964 μs (44 allocations: 3.11 MiB)    # re
  250.546 μs (35 allocations: 1.53 KiB)    # re!
  156.664 μs (39 allocations: 1.69 KiB)    # re_alt!

When I choose N = 100, I get the following timings:

  12.184 μs (43 allocations: 312.61 KiB)   # re
  21.374 μs (35 allocations: 1.53 KiB)     # re!
  7.651 μs (39 allocations: 1.69 KiB)      # re_alt!
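
As an aside on the "no reshape needed" point above, a tiny standalone illustration (not from the PR): copyto! fills the destination's linear, column-major memory, so a matrix can be written directly from a view of the flat vector.

y = zeros(2, 3)
flat = collect(1.0:10.0)
o = 2                                        # offset of y's block within flat
copyto!(y, view(flat, o .+ (1:length(y))))   # y now holds flat[3:8], no reshape or extra copy of flat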
mcabbott commented 4 months ago

Ah that looks great, thanks for digging! For me, with the example at top:

julia> @btime $re($params);  # This is the reconstruction cost
  min 92.167 μs, mean 301.432 μs (44 allocations, 3.11 MiB)

julia> @btime copy($params);  # ... and it's mostly allocation, same mean:
  min 97.333 μs, mean 309.699 μs (2 allocations, 3.11 MiB)

julia> @btime $re!($params);  # new version without reshape
  min 58.333 μs, mean 62.932 μs (39 allocations, 1.69 KiB)

and with N=100:

julia> @btime $re($params);
  min 7.167 μs, mean 25.760 μs (43 allocations, 312.67 KiB)

julia> @btime copy($params);
  min 4.944 μs, mean 31.490 μs (2 allocations, 310.67 KiB)

julia> @btime $re!($params);
  min 8.812 μs, mean 9.047 μs (39 allocations, 1.69 KiB)

I think the mean times are probably a better indication of the cost in actual use, when allocations differ so much, although possibly not perfect.

kishore-nori commented 4 months ago

Nice, that's good to know; the in-place version seems to be pretty stable in its timings. How do I make @btime output both min and mean like you have? (Sorry for going off-topic.)

And is the PR good to go?

mcabbott commented 4 months ago

Mean is from https://github.com/JuliaCI/BenchmarkTools.jl/pull/258, which I should eventually re-write to @btime / @btimes or @bmin / @btime or something.
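
With stock BenchmarkTools, both statistics can be read from a full @benchmark run, for example:

using BenchmarkTools, Statistics

b = @benchmark $re!($params)   # keeps all samples rather than reporting only the minimum
minimum(b), mean(b)            # minimum and mean trial estimates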

I see I did write some tests of this; it could all use one more look-over. There's a commented-out destructure!(flat::AbstractVector, x) which would let the other direction avoid allocating a whole copy too... not sure if that's almost ready or another whole project.
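
For illustration only (this is not the commented-out code in the PR), the flat-writing direction could look roughly like the sketch below, assuming Functors.fmap and the internal Optimisers.isnumeric, with the length check done after all the copies:

using Functors, Optimisers

function destructure_into!(flat::AbstractVector, x)   # hypothetical name, not the PR's
  i = Ref(0)                                          # running offset into flat
  fmap(x; exclude = Optimisers.isnumeric) do y        # isnumeric is an Optimisers internal
    copyto!(flat, i[] + 1, vec(y), 1, length(y))      # write this array's block into flat
    i[] += length(y)
    y
  end
  i[] == length(flat) || throw(DimensionMismatch("expected a vector of length $(i[]), got $(length(flat))"))
  flat
end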

kishore-nori commented 4 months ago

Mean is from https://github.com/JuliaCI/BenchmarkTools.jl/pull/258

Interesting! That would be very useful! I'll see if I can pirate some code out of it to start using locally :)

I didn't take a look at destructure!(flat::AbstractVector, x) earlier, but looking at it now it seems good, except that the length-compatibility check needs to happen after all the copyto! calls into flat are done, or in the last iteration. Could we do it beforehand using Flux.State? (Earlier I was using Flux.Params for these purposes, but that walks over all the layers to build itself, so there is some redundancy :/.) Another thing is the type check on flat: is it required here? Anyway, I think for most purposes destructure! is called just once (I can't think of cases other than pruning where it isn't), so even if destructure!(flat::AbstractVector, x) is not ready, the rest should be useful in its own right. But is there a reason why it was commented out? I can try looking into it.

mcabbott commented 4 months ago

I commented the method out just to focus on getting one thing working first. I believe it still needs tests, but otherwise this is nearly done.

Maybe I should check that my scary warning is true. I think something like this will return zero gradient:

v, re! = destructure!(model)
gradient(v) do w
  _ = re!(w)  # mutates model, in a way Zygote does not see
  sum(abs2, model(x))
end
kishore-nori commented 4 months ago

That's great!

Yes, you are right, it returns (nothing,); the `_` discards the result and breaks the connection in the chain of rrules.
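
For readers following along, a sketch of the contrast (assuming re! gets a differentiation rule analogous to re's, which this PR would need to provide): keeping and using the value returned by re! preserves the chain, while reading the mutated outer model does not.

using Flux, Zygote, Optimisers

x = rand(Float32, 28^2)
v, re! = destructure!(model)        # model as in the example at the top

gradient(v) do w
  _ = re!(w)                        # mutation is invisible to Zygote
  sum(abs2, model(x))               # reads the captured model, so this returns (nothing,)
end

gradient(v) do w
  m = re!(w)                        # use the returned model instead
  sum(abs2, m(x))                   # works only if re! has an rrule like re's (assumption)
end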