FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Improve movement of Grads #978

Open DhairyaLGandhi opened 3 years ago

DhairyaLGandhi commented 3 years ago

Grads stores some references alongside the calculated gradients. That's nice. When training large distributed GPU models, one often wants to materialise it in order to move it to a separate GPU on a separate node, synchronise there, and later apply the update to the model. This is difficult at present because Grads does not respect the dict iteration interface, which gives access to both the keys and the values. The values by themselves are hard to reason about, since there is no guarantee of order. In addition, any such operation on gradients usually requires the parameter information as well: one typically needs to move both the params and the grads to separate devices, rarely just one or the other.
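As a rough illustration of what "materialising" could mean here, the sketch below copies each parameter together with its gradient, in the iteration order of the Params. The helper name `materialise` is hypothetical and not part of Zygote; the sketch also assumes every parameter actually received a gradient (an unused parameter would yield `nothing`, which `copy` does not accept).

```julia
using Zygote

w, x, b = rand(2), rand(2), rand(2)
ps = Params([w, b])
gs = gradient(() -> sum(tanh.(w .* x .+ b)), ps)

# Hypothetical helper (not a Zygote function): copy each parameter together
# with its gradient, in the iteration order of `ps`, so that the result no
# longer depends on the identity of the original arrays.
materialise(ps, gs) = [(param = copy(p), grad = copy(gs[p])) for p in ps]

snapshot = materialise(ps, gs)
```

Each entry of `snapshot` is a plain named tuple of two arrays, so in principle it can be serialised or moved to another device without carrying references back to the original model.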

There is a copy!(Vector, Grads) method, but it doesn't solve the issue by itself. It is easy to go from multiple arrays to a vector, but typically harder to go from a vector back to multiple arrays, especially if one has to maintain one's own size information to recreate the gradients. Removing that seems the most logical step.
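To make the asymmetry concrete, here is a minimal sketch written by hand rather than with the copy! method mentioned above. Flattening is a one-liner, while the way back needs the shapes to be recorded separately. The helper `unflatten` and the `shapes` bookkeeping are assumptions made for illustration, not Zygote API.

```julia
using Zygote

w, x, b = rand(2, 3), rand(3), rand(2)
ps = Params([w, b])
gs = gradient(() -> sum(tanh.(w * x .+ b)), ps)

# Flatten: concatenate the gradients in `ps` order. The shapes have to be
# remembered separately, because the flat vector does not carry them.
shapes = [size(p) for p in ps]
flat = reduce(vcat, [vec(gs[p]) for p in ps])

# Hypothetical helper: slice the flat vector back into arrays of the
# recorded shapes, walking through it with a running offset.
function unflatten(flat::AbstractVector, shapes)
    arrays = Vector{Array{eltype(flat)}}()
    offset = 0
    for sz in shapes
        n = prod(sz)
        push!(arrays, reshape(flat[offset+1:offset+n], sz))
        offset += n
    end
    return arrays
end

restored = unflatten(flat, shapes)
```

The `shapes` vector is exactly the extra size information that the comment above describes having to maintain by hand.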

ToucheSir commented 3 years ago

Is this not an inherent limitation of implicit params, since parameter IDs will not be identical across workers?

DhairyaLGandhi commented 3 years ago

No. This can follow the same pathway as the regular data migration, but rather than pass the reference back, we need to pass a copy. The bottleneck is not using the existing Base implementation, and of course catching the IdDict.

darsnack commented 3 years ago

This doesn't address everything laid out here, but you can already iterate Grads by key-value pairs:

julia> using Zygote

julia> w, x1, x2, b = rand(2), rand(2), rand(2), rand(2);

julia> gs1 = gradient(() -> sum(tanh.(w .* x1 .+ b)), Params([w, b]))
Grads(...)

julia> for (k, v) in pairs(gs1)
       @show k, v
       end
(k, v) = ([0.6482442102960793, 0.016523264180973163], [0.24030749679167682, 0.49598729716935896])
(k, v) = ([0.6918785001353205, 0.8487855432074936], [0.2416676768164966, 0.5117878144292793])

ToucheSir commented 3 years ago

I think the bigger fundamental problem is knowing where the parameters slot into the model. Having a mechanism that is stable when you send them over the wire would help with that. Currently Params leans very heavily on object identity and thus stymies this somewhat. Perhaps if it had something like order but with indices instead.
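One way to picture the index-based idea is the sketch below. It is an assumption about how such a scheme might look, not an existing Zygote feature: gradients are keyed by their position in the Params instead of by the array object, and a second set of parameters (`rw`, `rb`, standing in for another worker) is updated by position. The step size `η` is arbitrary.

```julia
using Zygote

w, x, b = rand(2), rand(2), rand(2)
ps = Params([w, b])
gs = gradient(() -> sum(tanh.(w .* x .+ b)), ps)

# Key each gradient by its position in `ps` rather than by the array object.
indexed = Dict(i => copy(gs[p]) for (i, p) in enumerate(ps))

# Stand-in for a second worker: equal values, but different objects, so a
# lookup by object identity would not find them.
rw, rb = copy(w), copy(b)
remote_ps = Params([rw, rb])

# Apply a plain gradient step by position (η is an arbitrary step size).
η = 0.1
for (i, p) in enumerate(remote_ps)
    p .-= η .* indexed[i]
end
```

This only works if both sides construct their Params in the same order, which is the stability requirement raised in the comment above.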