Open raphaelchinchilla opened 2 years ago
Yes, they have quite different paths. `mapcols` handles everything in-house, as this was the only way I could make things work for Tracker. (As does `MapCols`, in a different way.)
But `slicemap` calls JuliennedArrays to handle the slices. The gradient rules for this will only work with Zygote. I have not investigated very closely but there may be some overhead in this.
> Would there be any advantage in using one implementation over the other?

After some more testing, I have concluded that in some situations one is better, and in other situations the other is better. I am not sure what the rule is.
> The gradient rules for this will only work with Zygote

It also works with ForwardDiff and ReverseDiff. Is that normal?
Also, a curious behavior that I have observed is that `mapcols` does not take the gradient of the parameters when one uses Zygote. It can be seen in this toy problem:
```julia
using SliceMap, ForwardDiff, ReverseDiff, Zygote

f(x, p) = [p * (x' * x)]
cost_slice(x, p) = sum(slicemap(x -> f(x, p), x, dims=1))
cost_each(x, p) = sum(mapcols(x -> f(x, p), x))

x = randn(10, 100)
p = rand()

# Using slicemap
ForwardDiff.gradient(x) do x
    cost_slice(x, p)
end
ForwardDiff.derivative(p) do p
    cost_slice(x, p)
end
Zygote.gradient(x) do x
    cost_slice(x, p)
end
Zygote.gradient(p) do p
    cost_slice(x, p)
end

# Using mapcols
ForwardDiff.gradient(x) do x
    cost_each(x, p)
end
ForwardDiff.derivative(p) do p
    cost_each(x, p)
end
Zygote.gradient(x) do x
    cost_each(x, p)
end
Zygote.gradient(p) do p
    cost_each(x, p)
end
# This returns (nothing,)
```
> It also works with ForwardDiff and ReverseDiff. Is that normal?

With these, this package is not involved in derivatives at all. I suspect that this means ReverseDiff is tracking each number, not whole arrays, and will be quite slow, but I haven't tested.
> Also, a curious behavior that I have observed is that `mapcols` does not take the gradient of the parameters

I was confused for a bit, but this is in fact expected. The help says:
> Any arguments after the matrix are passed to `f` as scalars, i.e. `mapcols(f, m, args...) = reduce(hcat, f(col, args...) for col in eachcol(m))`. They do not get sliced/iterated (unlike `map`), nor are their gradients tracked.
>
> Note that if `f` itself contains parameters, their gradients are also not tracked.

This was enough for what I needed. I don't quite recall whether tracking or accumulating the gradient of `f` (which contains `p`) was blocked by something particular.
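To make the quoted docstring concrete, here is a minimal sketch that uses only Base Julia and does not touch SliceMap's gradient code at all. `mapcols_ref` is a hypothetical name for the reference definition given in the help text, and the finite-difference check is only an illustration of what the derivative with respect to `p` should be for the toy problem, under the assumption that `f` is the toy function from above.

```julia
# Reference definition from the quoted docstring, written with Base only.
# `mapcols_ref` is a hypothetical name, not a SliceMap function.
mapcols_ref(f, m, args...) = reduce(hcat, [f(col, args...) for col in eachcol(m)])

f(x, p) = [p * (x' * x)]   # same toy function as in the example above

x = randn(10, 100)
p = rand()

# Passing p through a closure or as a trailing argument gives the same values.
y_closure = mapcols_ref(col -> f(col, p), x)
y_args    = mapcols_ref(f, x, p)

# The cost equals p * sum(abs2, x), so d(cost)/dp should be sum(abs2, x).
cost(p) = sum(mapcols_ref(f, x, p))
h = 1e-6
dcost_dp = (cost(p + h) - cost(p - h)) / (2h)   # central finite difference
```

Since the cost is linear in `p`, the finite difference should match `sum(abs2, x)` up to rounding. This is the value a reverse-mode gradient would need to return in place of `nothing` if parameter gradients were tracked.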
> I suspect that this means ReverseDiff is tracking each number, not whole arrays, and will be quite slow, but haven't tested.

With some light testing with the functions above, it seems that ReverseDiff is about 5 times slower than Zygote.
> The gradient rules for this will only work with Zygote.

Is there any technical reason not to implement them? I would gladly look into it. Or should we just hope that `stack` will be merged soon enough and this would be a waste of time?
The implementations of `slicemap` and `mapcols` are fundamentally different. Intuitively, this is somewhat odd because `mapcols(f,M)=slicemap(f,M,dims=1)` and `slicemap(f,M,dims=1)=reshape(mapcols(f,reshape(M,size(M,1),:)),size(M))` (if `dims` is not equal to 1, then one could just use `PermutedDimsArray`).
After some (light) testing, I have the impression that using `mapcols` is about 25% faster than using `slicemap`. Is that a general result or specific to my application? Would there be any advantage in using one implementation over the other?
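The reshape identity above can be checked without either package. The sketch below is Base-only: `mapcols_ref` and `slicemap_via_cols` are hypothetical stand-ins (not SliceMap functions), and it assumes `f` returns a vector of the same length as the slice it receives, which is what the final `reshape` to `size(A)` requires.

```julia
# Base-only stand-in for mapcols; hypothetical name, not the SliceMap function.
mapcols_ref(f, M) = reduce(hcat, [f(col) for col in eachcol(M)])

# slicemap along dims=1 expressed through column mapping plus reshapes.
# Only valid when f preserves the length of each slice.
slicemap_via_cols(f, A) =
    reshape(mapcols_ref(f, reshape(A, size(A, 1), :)), size(A))

g(v) = v .* sum(v)             # length-preserving test function

A = randn(3, 4, 5)
B = slicemap_via_cols(g, A)
C = mapslices(g, A; dims=1)    # Base reference for slices along the first dimension
```

Here `B` and `C` should agree elementwise. This says nothing about performance or gradients, only that the two formulations compute the same array.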