JuliaML / MLUtils.jl

Utilities and abstractions for Machine Learning tasks
MIT License

DataLoader incompatible with Flux/Zygote #127

Closed: simonmandlik closed this issue 1 year ago

simonmandlik commented 1 year ago

Due to the try/catch in the implementation of DataLoader, Zygote.jl cannot differentiate through the iteration:

using Flux, MLUtils

x = randn(10, 10)
m = Dense(10, 10)
ps = Flux.params(m)
mbs = MLUtils.DataLoader(x, batchsize=4, shuffle=true)

julia> mb_grad = gradient(() -> sum(m(first(mbs))), ps)
ERROR: Compiling Tuple{MLUtils.var"##BatchView#28", Int64, Bool, Val{nothing}, Type{BatchView}, ObsView{Matrix{Float64}, Vector{Int64}}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] instrument(ir::IRTools.Inner.IR)
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/reverse.jl:121
  [3] #Primal#23
    @ ~/.julia/packages/Zygote/dABKa/src/compiler/reverse.jl:205 [inlined]
  [4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/reverse.jl:330
  [5] _generate_pullback_via_decomposition(T::Type)
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/emit.jl:101
  [6] #s2924#1068
    @ ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:28 [inlined]
  [7] var"#s2924#1068"(::Any, ctx::Any, f::Any, args::Any)
    @ Zygote ./none:0
  [8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:582
  [9] _pullback
    @ ~/.julia/packages/MLUtils/Th9Y3/src/batchview.jl:92 [inlined]
 [10] _pullback(::Zygote.Context{true}, ::Core.var"#Type##kw", ::NamedTuple{(:batchsize, :partial, :collate), Tuple{Int64, Bool, Val{nothing}}}, ::Type{BatchView}, ::ObsView{Matrix{Float64}, Vector{Int64}})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [11] _pullback
    @ ~/.julia/packages/MLUtils/Th9Y3/src/eachobs.jl:161 [inlined]
 [12] _pullback(ctx::Zygote.Context{true}, f::typeof(iterate), args::DataLoader{Matrix{Float64}, Random._GLOBAL_RNG, Val{nothing}})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [13] _pullback
    @ ./abstractarray.jl:424 [inlined]
 [14] _pullback(ctx::Zygote.Context{true}, f::typeof(first), args::DataLoader{Matrix{Float64}, Random._GLOBAL_RNG, Val{nothing}})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [15] _pullback
    @ ./REPL[10]:1 [inlined]
 [16] _pullback(::Zygote.Context{true}, ::var"#5#6")
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [17] pullback(f::Function, ps::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:373
 [18] gradient(f::Function, args::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:96
 [19] top-level scope
    @ REPL[10]:1
 [20] top-level scope
    @ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52

(jl_PGRlDC) pkg> st
Status `/private/var/folders/pr/bshcrht5423cbhtz_ls_h_7r0000gp/T/jl_PGRlDC/Project.toml`
  [587475ba] Flux v0.13.6
  [f1d291b0] MLUtils v0.2.11

However, similar code worked with the old MLDataPattern:

using Flux, MLDataPattern

x = randn(10, 10)
m = Dense(10, 10)
ps = Flux.params(m)
mbs = RandomBatches(x, size=4)

julia> mb_grad = gradient(() -> sum(m(first(mbs))), ps)
Grads(...)

(jl_NW4gr7) pkg> st
Status `/private/var/folders/pr/bshcrht5423cbhtz_ls_h_7r0000gp/T/jl_NW4gr7/Project.toml`
  [587475ba] Flux v0.13.6
  [9920b226] MLDataPattern v0.5.5

ToucheSir commented 1 year ago

That the old dataloader worked is probably a happy accident. Is there any reason you can't pull a batch from the dataloader (e.g. extract first(mbs) outside of the lambda) before calling gradient? Even if it did work before, it was likely causing unnecessary performance issues, because Zygote has to differentiate through the DataLoader code.
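
A minimal sketch of that suggestion, reusing the setup from the original report. It assumes the implicit-parameter API (Flux.params) of the Flux version shown in the thread; the name xb is only illustrative, and the input is created as Float32 here to match the default parameter type of Dense.

```julia
using Flux, MLUtils

x = randn(Float32, 10, 10)
m = Dense(10, 10)
ps = Flux.params(m)
mbs = MLUtils.DataLoader(x, batchsize=4, shuffle=true)

# Pull the batch *outside* the closure, so Zygote only has to
# differentiate the model call and never traces DataLoader code.
xb = first(mbs)   # hypothetical name for the extracted minibatch

mb_grad = gradient(() -> sum(m(xb)), ps)
```

Because xb is an ordinary array captured by the closure, the try/catch inside the BatchView constructor is never seen by Zygote.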

simonmandlik commented 1 year ago

No, not really. The original idea was to make code like this work out of the box with the highest-level API (like Flux.train!, but looking at its code, it also loops over minibatches and computes the gradient for each one separately).
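
For reference, the per-minibatch pattern mentioned above can be sketched roughly as follows. This is not the actual Flux.train! implementation, only an illustration that assumes the implicit-parameter API from the thread; the loss function, the step size lr, and the hand-written gradient-descent update are all assumptions made for the example.

```julia
using Flux, MLUtils

x = randn(Float32, 10, 10)
m = Dense(10, 10)
ps = Flux.params(m)
loss(xb) = sum(m(xb))   # illustrative loss, same expression as in the report
lr = 0.1f0              # hypothetical step size

nbatches = 0
for xb in MLUtils.DataLoader(x, batchsize=4, shuffle=true)
    # The DataLoader is iterated outside of `gradient`;
    # only the loss on the current batch is differentiated.
    gs = gradient(() -> loss(xb), ps)

    # Hand-rolled gradient-descent step (stand-in for an optimiser update).
    for p in ps
        p .-= lr .* gs[p]
    end
    global nbatches += 1
end
```

The iteration over the loader stays plain Julia code, so its internals never need to be differentiable.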

Thanks!