simonmandlik closed this issue 1 year ago
That the old dataloader worked is probably a happy accident. Is there any reason you can't pull a batch from the dataloader (e.g. extract `first(mbs)` outside of the lambda) before calling `gradient`? Even if it did work before, it was likely causing unnecessary performance issues because Zygote has to differentiate through the DataLoader code.
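For illustration, here is a minimal sketch of that suggestion. It assumes a recent Flux version with explicit-gradient support, and that `mbs` is a `Flux.DataLoader`; the toy data, model, and loss are made up for the example and are not taken from the original report.

```julia
using Flux

# Hypothetical toy data and model, for illustration only.
X = rand(Float32, 4, 32)
Y = rand(Float32, 1, 32)
model = Dense(4 => 1)
mbs = Flux.DataLoader((X, Y); batchsize=8)

# Pull the batch out *before* calling gradient, so the iteration
# (and the try/catch inside DataLoader) is never differentiated.
x, y = first(mbs)

# Zygote now only has to differentiate the loss computation itself.
grads = Flux.gradient(m -> Flux.Losses.mse(m(x), y), model)
```

The same pattern carries over to a loop: iterate with `for (x, y) in mbs` and call `gradient` inside the loop body, once per minibatch.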
No, not really. The original idea was to make code like this work out of the box with the highest-level API (like `Flux.train!`, but looking at its code, it also loops over minibatches and computes the gradient for each separately).
Thanks!
Due to the `try/catch` in the implementation of `DataLoader`, `Zygote.jl` cannot differentiate through the iteration:

But similar code worked in the old `MLDataPattern`: