FluxML / Optimisers.jl

Optimisers.jl defines many standard optimisers and utilities for learning loops.
https://fluxml.ai/Optimisers.jl
MIT License

GPU kernels for optimizers #178

Open vpuri3 opened 2 weeks ago

vpuri3 commented 2 weeks ago

Motivation and description

Wondering what kind of speedup can be achieved by writing GPU kernels for optimizers.

Take a look at @pxl-th's implementation of Adam below:

https://github.com/JuliaNeuralGraphics/NerfUtils.jl/blob/main/src/nn/adam.jl#L100-L117

Possible Implementation

No response

pxl-th commented 2 weeks ago

NerfUtils.jl fuses the whole update into a single kernel, while Optimisers.jl splits it across four (counting the actual parameter update).

For smaller arrays the benefit is negligible, but for something on the order of 400+ MB the fused version is around 2x faster.
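
For a sense of what a fused update looks like, here is a minimal KernelAbstractions sketch of a single-pass Adam step (an illustration only, with hypothetical names; the actual NerfUtils.jl kernel is the one linked above):

using KernelAbstractions

# One fused pass: update both moment buffers, apply bias correction, and write
# the parameter, all from a single read of dx.
@kernel function adam_step!(x, m, v, @Const(dx), β1, β2, η, ϵ, t)
    i = @index(Global)
    m[i] = β1 * m[i] + (1 - β1) * dx[i]
    v[i] = β2 * v[i] + (1 - β2) * dx[i]^2
    mhat = m[i] / (1 - β1^t)
    vhat = v[i] / (1 - β2^t)
    x[i] -= η * mhat / (sqrt(vhat) + ϵ)
end

# Launched on whichever backend owns the arrays, e.g.
#   adam_step!(get_backend(x))(x, m, v, dx, 0.9f0, 0.999f0, 1f-3, 1f-8, t; ndrange = length(x))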

MWE:

using AMDGPU
using BenchmarkTools
using KernelAbstractions
using Flux
using NerfUtils

function main()
    x = AMDGPU.rand(Float32, 100 * 1024^2)   # ~400 MB of Float32 parameters
    dx = AMDGPU.rand(Float32, 100 * 1024^2)  # matching gradient

    kab = get_backend(x)

    # Fused single-kernel Adam from NerfUtils.jl vs. the broadcast-based
    # Adam rule from Optimisers.jl (accessed here through Flux).
    opt_1 = NerfUtils.Adam(kab, x)
    opt_2 = Flux.Optimisers.Adam()
    state = Flux.Optimisers.init(opt_2, x)

    @btime AMDGPU.@sync NerfUtils.step!($opt_1, $x, $dx; dispose=false)

    @btime AMDGPU.@sync begin
        ns, nx = Flux.Optimisers.apply!($opt_2, $state, $x, $dx)
        $x .-= nx
    end
    return
end

Timings:

julia> main()
  6.168 ms (395 allocations: 10.13 KiB)
  13.161 ms (339 allocations: 9.09 KiB)

ToucheSir commented 2 weeks ago

Optimisers.jl rules are written the way they are because we have to balance a few things. To demonstrate, let's look at the implementation of Adam: https://github.com/FluxML/Optimisers.jl/blob/c2ae321518b2948dc56af3357f6a206b511c7b3e/src/rules.jl#L219-L221
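
For reference, the three linked lines follow roughly this pattern (a paraphrase, not a verbatim copy; the pinned commit is authoritative). Here mt, vt and βt come from the rule's state, while η, β and ϵ are the Adam hyperparameters:

# Paraphrased from the linked rule: two sometimes-in-place broadcasts for the
# moment buffers, then a lazy (unmaterialized) expression for the update.
@.. mt = β[1] * mt + (1 - β[1]) * dx
@.. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
dx′ = @lazy mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) * η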

  1. Broad array type compatibility: that @.. macro is not a typo. It's a custom version of @. which writes in place where possible and returns a new array for immutable array types. Custom kernels only work with mutable array types, and only a limited number of them at that.
  2. Deferring work where possible: "...while Optimisers.jl splits it across four (counting the actual parameter update)" is off by one, because @lazy means dx′ is a Broadcasted rather than a materialized array. We do this to ensure better fusion with subsequent steps (think of how AdamW composes Adam + WeightDecay), as well as fusion with the final parameter update. Writing a standalone GPU kernel for each AbstractRule would mean losing that fusion (see the sketch after this list).
  3. Legibility and ease of entry: most people who contribute rules to Optimisers.jl are not super familiar with writing GPU code. Our current system for writing rules seems to be pretty accessible, since most of the work is translating statements of math into statements of array-level Julia code. Unless we want to make follow-up PRs for GPU kernels every time a new rule is added, we'll want a system that lowers the barrier to entry somehow.
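
To make point 2 concrete, here is a toy sketch using plain Base broadcasting (not Optimisers.jl's actual @lazy machinery) of how a lazily built update fuses with the final parameter write:

using Base.Broadcast: broadcasted

x  = rand(Float32, 1000)
mt = rand(Float32, 1000)
vt = rand(Float32, 1000)
η, ϵ = 1f-3, 1f-8

# An Adam-like update expression built lazily: nothing is computed or allocated here.
dx′ = broadcasted((m, v) -> η * m / (sqrt(v) + ϵ), mt, vt)

# Because dx′ is still a Broadcasted, the subtraction below fuses the whole
# expression into a single broadcast pass over x, mt and vt.
x .-= dx′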

Those are the constraints. My thinking on how to proceed is that we need a design which addresses most of them. Priority-wise, my ranking would be 1) maintaining laziness when rules are composed, 2) maximizing code reuse with non-GPU arrays, and 3) lowering the barrier to entry so people don't have to understand all of KernelAbstractions to get started. This all seems doable, but it would require a champion to flesh out a design and push it forward.