mcabbott / SliceMap.jl

Same-same but different

Difference in the implementation of slicemap and mapcols #14

Open raphaelchinchilla opened 2 years ago

raphaelchinchilla commented 2 years ago

The implementations of slicemap and mapcols are fundamentally different. Intuitively this is odd, because for a matrix M

`mapcols(f, M) == slicemap(f, M, dims=1)`, and for an array of any dimension `slicemap(f, M, dims=1) == reshape(mapcols(f, reshape(M, size(M, 1), :)), :, size(M)[2:end]...)` (if dims is not 1, one could just use PermutedDimsArray).
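Here is a minimal, self-contained sketch of that equivalence in plain Julia. `mapcols_ref`, `slicemap1_ref`, `via_mapcols` and `slicemap2_via_mapcols` are hypothetical helpers that only mimic the documented semantics; they are not SliceMap's implementation and carry no gradient rules. The test function `g` deliberately changes the slice length from 3 to 2, which is why the column-wise result is reshaped to `(:, size(M)[2:end]...)` rather than to `size(M)` (the latter only works when f preserves the slice length).

```julia
# Hypothetical reference helpers (not SliceMap's code): they only mimic the semantics.

# mapcols: apply f to every column and hcat the results.
mapcols_ref(f, M::AbstractMatrix) = reduce(hcat, [f(c) for c in eachcol(M)])

# slicemap(f, M, dims=1): apply f to every slice M[:, I] and put the
# result back along the first dimension.
function slicemap1_ref(f, M::AbstractArray)
    tails = CartesianIndices(size(M)[2:end])
    outs = [f(M[:, I]) for I in tails]          # one output vector per slice
    R = zeros(eltype(first(outs)), length(first(outs)), size(M)[2:end]...)
    for I in tails
        R[:, I] .= outs[I]
    end
    return R
end

# dims=1 via mapcols: flatten trailing dims into columns, map, then restore them.
via_mapcols(f, M::AbstractArray) =
    reshape(mapcols_ref(f, reshape(M, size(M, 1), :)), :, size(M)[2:end]...)

# dims=2 (3-d case) via PermutedDimsArray: move dim 2 to the front, map, move it back.
slicemap2_via_mapcols(f, M::AbstractArray{<:Any,3}) =
    permutedims(via_mapcols(f, PermutedDimsArray(M, (2, 1, 3))), (2, 1, 3))

M = randn(3, 4, 5)
g(v) = [sum(v), maximum(v)]   # maps a slice of any length to a 2-vector

A = slicemap1_ref(g, M)       # size (2, 4, 5)
B = via_mapcols(g, M)         # same values, computed column-wise
C = slicemap2_via_mapcols(g, M)
```

Because `reshape(M, size(M, 1), :)` is column-major, column k of the flattened matrix is exactly the k-th slice in `CartesianIndices` order, so no copying or index bookkeeping is needed for `dims=1`.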

After some (light) testing, I have the impression that mapcols is about 25% faster than slicemap. Is that a general result, or specific to my application? Would there be any advantage to using one implementation over the other?

mcabbott commented 2 years ago

Yes, they take quite different paths. mapcols handles everything in-house, as this was the only way I could make it work with Tracker. (So does MapCols, in a different way.)

But slicemap calls JuliennedArrays to handle the slices. The gradient rules for this will only work with Zygote. I have not investigated closely, but there may be some overhead there.

raphaelchinchilla commented 2 years ago

> Would there be any advantage to using one implementation over the other?

After some more testing, I have concluded that in some situations one is better and in others the other is. I am not sure what the rule is.

> The gradient rules for this will only work with Zygote

It also works with ForwardDiff and ReverseDiff. Is that normal?

Also, a curious behavior I have observed is that, with Zygote, mapcols does not compute the gradient with respect to the parameters. It can be seen in this toy problem:

```julia
using SliceMap, ForwardDiff, ReverseDiff, Zygote

# Each column contributes the 1-element vector [p * (x' * x)]
f(x, p) = [p * (x' * x)]

cost_slice(x, p) = sum(slicemap(col -> f(col, p), x, dims=1))
cost_each(x, p) = sum(mapcols(col -> f(col, p), x))

x = randn(10, 100)
p = rand()

# Using slicemap

ForwardDiff.gradient(x) do x
    cost_slice(x, p)
end

ForwardDiff.derivative(p) do p
    cost_slice(x, p)
end

Zygote.gradient(x) do x
    cost_slice(x, p)
end

Zygote.gradient(p) do p
    cost_slice(x, p)
end

# Using mapcols

ForwardDiff.gradient(x) do x
    cost_each(x, p)
end

ForwardDiff.derivative(p) do p
    cost_each(x, p)
end

Zygote.gradient(x) do x
    cost_each(x, p)
end

Zygote.gradient(p) do p
    cost_each(x, p)
end
# This returns (nothing,)
```

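For reference, the toy cost has simple closed-form derivatives, so the expected outputs of the calls above can be checked by hand. Writing x_j for the j-th column of x:

```latex
\mathrm{cost}(x, p) = \sum_{j} p\, x_j^{\top} x_j,
\qquad
\frac{\partial\,\mathrm{cost}}{\partial p} = \sum_{j} x_j^{\top} x_j = \sum_{i,j} x_{ij}^{2},
\qquad
\frac{\partial\,\mathrm{cost}}{\partial x} = 2p\, x
```

So the mathematically expected result of the last call is `(sum(abs2, x),)`, not `(nothing,)`, while the gradient with respect to x should equal `2p .* x` for both versions.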
mcabbott commented 2 years ago

> It also works with ForwardDiff and ReverseDiff. Is that normal?

With these, this package is not involved in derivatives at all. I suspect that this means ReverseDiff is tracking each number, not whole arrays, and will be quite slow, but haven't tested.

> Also, a curious behavior I have observed is that, with Zygote, mapcols does not compute the gradient with respect to the parameters

I was confused for a bit, but this is in fact expected. The help says:

  Any arguments after the matrix are passed to `f` as scalars, i.e. `mapcols(f, m, args...) =
  reduce(hcat, f(col, args...) for col in eachcol(m))`. They do not get sliced/iterated (unlike
  map), nor are their gradients tracked.

  Note that if `f` itself contains parameters, their gradients are also not tracked.

This was enough for what I needed. I don't quite recall whether tracking or accumulating the gradient of f (which closes over p) was blocked by anything in particular.
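A minimal sketch of why a column-wise rule drops p, assuming a hand-written reverse rule in the spirit of the docstring (this is not SliceMap's actual code). `mapcols_with_pullback` and `f_and_pb` are hypothetical names, and the pullback of f is written out by hand instead of coming from an AD package. The rule only ever returns a cotangent for the matrix, so the parameter p captured inside f has nowhere to receive its gradient, which matches the `(nothing,)` result in the toy problem.

```julia
# Hypothetical column-wise reverse rule in the spirit of mapcols (not SliceMap's code).
# `f_and_pb(col)` returns (y, back), where back(dy) is the cotangent for col only.
function mapcols_with_pullback(f_and_pb, M::AbstractMatrix)
    results = [f_and_pb(c) for c in eachcol(M)]
    Y = reduce(hcat, [r[1] for r in results])
    function mapcols_pullback(dY)
        # Each column's pullback sees only its own slice of dY and returns a
        # cotangent for that column. Nothing here can accumulate a gradient
        # for p, which is captured inside f_and_pb.
        dM = reduce(hcat, [r[2](dy) for (r, dy) in zip(results, eachcol(dY))])
        return (nothing, dM)   # (cotangent for f, cotangent for M)
    end
    return Y, mapcols_pullback
end

p = 0.7
# f(x) = [p * (x' * x)], whose pullback is dy -> 2p * x * dy[1]
f_and_pb(x) = ([p * (x' * x)], dy -> 2p .* x .* dy[1])

x = randn(10, 100)
Y, back = mapcols_with_pullback(f_and_pb, x)
df, dx = back(ones(size(Y)))   # a seed of ones corresponds to sum(Y)
```

Tracking p would require the rule to also differentiate f itself (for example by summing each column's contribution to the captured parameters), which this sketch intentionally does not do.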

raphaelchinchilla commented 2 years ago

> I suspect that this means ReverseDiff is tracking each number, not whole arrays, and will be quite slow, but haven't tested.

From some light testing with the functions above, ReverseDiff seems to be about 5 times slower than Zygote.

> The gradient rules for this will only work with Zygote.

Is there any technical reason not to implement them? I would gladly look into it. Or should we just hope that `stack` is merged soon enough, in which case this would be a waste of time?
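For context, `stack` did land in Base in Julia 1.9, where `eachslice` also gained support for several dimensions at once. A minimal sketch of a slicemap-style computation using only Base functions, assuming Julia 1.9 or later; the slice function `g` is a made-up example, and whether a given AD package differentiates this efficiently is not checked here.

```julia
# Requires Julia >= 1.9 (Base.stack, and eachslice over several dims).
M = randn(3, 4, 5)
g(v) = [sum(v), maximum(v)]   # example slice function: length-3 slice -> 2-vector

# Similar in spirit to slicemap(g, M, dims=1): iterate over dims 2 and 3,
# apply g to each slice M[:, j, k], and stack the outputs along a new first dim.
R = stack(g, eachslice(M; dims=(2, 3)))
```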