CTUAvastLab / Mill.jl

Build flexible hierarchical multi-instance learning models.
https://ctuavastlab.github.io/Mill.jl/stable/
MIT License

Pevnak/generated #96

Closed · pevnak closed this 2 years ago

pevnak commented 2 years ago

This adds a generated function to efficiently handle the case when the keys in the NamedTuples of ProductNode and ProductModel do not match. When they match, the old method relying on map still works; when they do not, a generated function emits efficient code.
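For illustration, here is a minimal sketch of such a key-matching generated function. It is hypothetical (apply_by_key is not the PR's actual code), but it shows the idea: the keys are resolved during code generation, so the emitted expression is fully unrolled and performs no key lookup at runtime.

# Hypothetical sketch: apply each submodel to the child stored under the
# same key, whatever order the data NamedTuple uses, and stack the results.
@generated function apply_by_key(ms::NamedTuple{MK}, xs::NamedTuple{XK}) where {MK, XK}
    calls = Expr[]
    for k in MK
        push!(calls, :(ms.$k(xs.$k)))   # one unrolled call per model key
    end
    :(vcat($(calls...)))                # concatenate the results
end

For MK = (:a, :b, :c) the emitted code is just vcat(ms.a(xs.a), ms.b(xs.b), ms.c(xs.c)), regardless of the order of keys in xs.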

Test bench:

using Mill, Flux

# three leaf nodes of the same shape
a = ArrayNode(randn(Float32, 3, 4))
b = ArrayNode(randn(Float32, 3, 4))
c = ArrayNode(randn(Float32, 3, 4))
x₁ = ProductNode((;a, b, c))              # keys in the order (a, b, c)
m = reflectinmodel(x₁, d -> Dense(d, 4))  # model keys follow x₁
x₂ = ProductNode((;c, b, a))              # same children, permuted keys
m(x₁).data ≈ m(x₂).data                   # true: children are matched by key

The generated function emits code that Zygote handles much better. Since the first call to gradient is dominated by compilation, @elapsed on it approximates compile time. Compile time of the version based on map:

ps = Flux.params(m)
julia> @elapsed gradient(() -> sum(m(x₁).data), ps)
27.183408771

Compile time of the version that uses the generated function:

julia> @elapsed gradient(() -> sum(m(x₂).data), ps) 
25.148775435

so we see what was expected: the generated version compiles slightly faster.
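This matches the intuition of why the generated function is nicer to Zygote: conceptually, it emits fully unrolled code along these lines (a hypothetical expansion; the field names ms, m, and data mirror ProductModel and ProductNode, but the actual emitted code may differ), so Zygote differentiates a plain sequence of calls instead of a map over a NamedTuple.

# conceptual, hand-written expansion for model keys (a, b, c) applied to x₂
m.m(vcat(m.ms.a(x₂.data.a), m.ms.b(x₂.data.b), m.ms.c(x₂.data.c)))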

Regarding the runtime of the forward pass, both versions seem to perform the same (recall that x₁ goes through map and x₂ through the generated function):

julia> using BenchmarkTools

julia> @btime m(x₁)
  1.241 μs (13 allocations: 1.70 KiB)

julia> @btime m(x₂)
  1.274 μs (13 allocations: 1.70 KiB)

During the gradient computation, however, the generated function is more efficient:

julia> @btime gradient(() -> sum(m(x₁).data), ps)
  202.678 μs (917 allocations: 58.92 KiB)
Grads(...)

julia> @btime gradient(() -> sum(m(x₂).data), ps)
  168.142 μs (825 allocations: 52.62 KiB)
Grads(...)

Therefore, I leave it up to the committee (@SimonMandlik @racinmat) to decide whether we should use the generated function by default.