FluxML / Zygote.jl

https://fluxml.ai/Zygote.jl/

Differentiable tmap? #743

Open ChrisRackauckas opened 4 years ago

ChrisRackauckas commented 4 years ago

Where should this go?

using Zygote

function tmap(f,args...)
  batch_data = Vector{Any}(undef,length(args[1]))
  Threads.@threads for i in 1:length(args[1])
      batch_data[i] = f(getindex.(args,i)...)
  end
  map(identity,reduce(vcat,batch_data)) # reduce and tighten eltype
end
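For reference, a minimal usage sketch of the forward pass (hypothetical example, assuming the `tmap` definition above is in scope and Julia was started with multiple threads, e.g. `julia -t 4`):

```julia
xs = [1.0, 2.0, 3.0, 4.0]
# Square each element, with iterations distributed across threads.
ys = tmap(x -> x^2, xs)   # == [1.0, 4.0, 9.0, 16.0]
```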

function ∇tmap(cx, f, args...)
  ys_and_backs = tmap((args...) -> Zygote._pullback(cx, f, args...), args...)
  if isempty(ys_and_backs)
    ys_and_backs, _ -> nothing
  else
    ys, backs = Zygote.unzip(ys_and_backs)
    function ∇tmap_internal(Δ)
      # Split Δ into non-overlapping chunks matching each output's length.
      lengths = vcat(0, cumsum([length(ys[i]) for i in 1:length(ys)]))
      Δ_split = [Δ[lengths[i]+1:lengths[i+1]] for i in 1:length(ys)]
      # Run the pullbacks on threads as well.
      Δf_and_args_zipped = tmap((f, δ) -> f(δ), backs, Δ_split)
      Δf_and_args = Zygote.unzip(Δf_and_args_zipped)
      Δf = reduce(Zygote.accum, Δf_and_args[1])
      (Δf, Δf_and_args[2:end]...)
    end
    reduce(vcat, ys), ∇tmap_internal
  end
end

Zygote.@adjoint function tmap(f, args::Union{AbstractArray,Tuple}...)
  ∇tmap(__context__, f, args...)
end

Note that it has a few extra tidbits over the normal one in order to handle the reduction.
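As a rough sketch of how one might exercise the adjoint (hypothetical example, assuming the definitions above are loaded and Zygote is available; not tested code from the thread):

```julia
using Zygote

xs = [1.0, 2.0, 3.0]
# Differentiate through the threaded map; the custom adjoint routes the
# pullbacks through tmap as well, so the backward pass is also threaded.
grad = Zygote.gradient(v -> sum(tmap(x -> x^2, v)), xs)[1]
```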

racinmat commented 4 years ago

I guess it should go somewhere similarly to https://github.com/FluxML/Zygote.jl/blob/master/src/lib/array.jl#L183 ? Although I'm wondering if it wouldn't be better to use `tmap` from e.g. https://github.com/baggepinnen/ThreadTools.jl, which is a little more flexible.

ChrisRackauckas commented 4 years ago

Those are built over `@spawn`.
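For context, an illustrative sketch (assumed, not from the thread) of the difference: a ThreadTools-style `tmap` spawns one task per element, while the version above uses a single `Threads.@threads` loop.

```julia
# Spawn-per-element variant (hypothetical name tmap_spawn): one task per
# element, fetched in order.
function tmap_spawn(f, xs)
  tasks = [Threads.@spawn f(x) for x in xs]
  fetch.(tasks)
end

# Trade-off: Threads.@threads statically partitions the loop over the
# available threads, which has lower scheduling overhead for cheap,
# uniform work; per-element @spawn gives dynamic load balancing when f
# is expensive or its cost varies across elements.
```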

DhairyaLGandhi commented 4 years ago

Presumably it's already defined in a package somewhere?

racinmat commented 4 years ago

Is it? It would be great to have some list of packages that add Zygote support for various functionalities, because otherwise I'm a bit lost, and performing a random walk over packages that depend on Zygote is a bit time consuming :D