mcabbott / SliceMap.jl

Same-same but different

In slicemap(f, A; dims) f is restricted to Function type #6

Closed: AzamatB closed this issue 4 years ago

AzamatB commented 4 years ago

In slicemap(f, A; dims), f is restricted to the Function type. This prevents passing Flux layers to it, since they are callable structs rather than subtypes of Function. Can we fix this?

mcabbott commented 4 years ago

Yes, there's no real need to restrict f.
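To illustrate the point in plain Julia (no Flux needed): a callable struct is not a subtype of Function, so a method whose signature says f::Function rejects it with a MethodError, while leaving f unannotated accepts it. The names restricted_map, unrestricted_map and Scale are hypothetical, chosen for this sketch; this is not SliceMap's actual code.

```julia
# Hypothetical stand-ins (not SliceMap's code): a "map over columns" helper,
# once with f restricted to Function and once with f left unannotated.
restricted_map(f::Function, M::AbstractMatrix) = reduce(hcat, [f(c) for c in eachcol(M)])
unrestricted_map(f, M::AbstractMatrix) = reduce(hcat, [f(c) for c in eachcol(M)])

# A callable struct, like a Flux layer: it can be called, but it is not a Function.
struct Scale
    a::Float64
end
(s::Scale)(x) = s.a .* x

M = [1.0 2.0; 3.0 4.0]
s = Scale(2.0)

s isa Function            # false
unrestricted_map(s, M)    # [2.0 4.0; 6.0 8.0]

# The restricted version only accepts true functions (including anonymous ones).
restricted_map(x -> 2 .* x, M)

# Passing the callable struct to it throws a MethodError.
threw = try
    restricted_map(s, M)
    false
catch err
    err isa MethodError
end
```

Dropping the annotation costs nothing here, since Julia still specializes the method on the concrete type of f.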

When f contains parameters, I'm not sure whether mapcols(f, M) will compute the gradient with respect to f correctly. For slicemap in particular, though, I'd guess it should work. What's the simplest possible test of this?

AzamatB commented 4 years ago

I would like to use it in the forward pass of a BLSTM (bidirectional LSTM), so my use case is something like:

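using Flux
using SliceMap
using Test   # for @test

m = LSTM(3,2) |> gpu
xs = rand(Float32, 3,7,4) |> gpu

# dims lists the slice dimensions (as in mapslices), so each slice is xs[:, j, :]
# LSTM is stateful: reset it so both sides start from the same hidden state
Flux.reset!(m); lhs = slicemap(m, xs; dims=(1,3))
Flux.reset!(m); rhs = reduce(hcat, map(x -> reshape(m(x), 2,1,4), eachslice(xs; dims=2)))
@test lhs ≈ rhs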

Is this simple enough?
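One thing worth double-checking in a test like this is the dims convention. Base's mapslices takes dims as the dimensions that make up each slice, so slices of the form xs[:, j, :] correspond to dims=(1,3), not dims=2. Below is a minimal sketch with a plain matrix multiply standing in for the LSTM (W, layer and the data are made up for illustration), assuming slicemap follows the same dims convention as mapslices:

```julia
# Made-up stand-in for a layer: a fixed 2×3 weight matrix.
W = [1.0 0.0 2.0;
     0.0 1.0 1.0]
layer(x) = W * x                   # maps a 3×4 slice to a 2×4 output

xs = reshape(collect(1.0:84.0), 3, 7, 4)

# dims=(1,3): each slice is xs[:, j, :] (3×4); outputs are stacked along dim 2.
A = mapslices(layer, xs; dims=(1, 3))

# The same computation written with eachslice, as in the test above.
B = reduce(hcat, map(x -> reshape(layer(x), 2, 1, 4), eachslice(xs; dims=2)))

size(A)   # (2, 7, 4)
A == B    # true
```

With dims=2 instead, mapslices would hand layer a 7-element vector (xs[i, :, k]), which a layer expecting 3 input features cannot accept.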

mcabbott commented 4 years ago

That only checks the forward pass, though (and it depends on Flux). I meant a check that the gradient works; something like this, perhaps, compared against ForwardDiff?

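using SliceMap, Zygote   # mapcols; gradient, Params

struct F W end           # callable struct holding a parameter, like a tiny layer
(f::F)(x) = f.W * x
w = ones(2,2)
x = rand(2,3)
# implicit-parameter gradient of the summed output with respect to both w and x
g = gradient(() -> sum(mapcols(F(w), x)), Params([w,x]))
g[w], g[x]               # the values to compare against ForwardDiff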
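For the comparison, the loss in this example is linear, so the expected gradients can be written down directly: sum(w * x) has gradient ones(2,3) * x' with respect to w and w' * ones(2,3) with respect to x. The sketch below checks those closed forms with central finite differences, swapped in for ForwardDiff so that it runs with Base Julia alone; loss and fd_gradient are hypothetical helpers, not part of SliceMap or Zygote. If mapcols handles the parameters correctly, g[w] and g[x] from the Zygote snippet should match dw_exact and dx_exact.

```julia
# Mirrors sum(mapcols(F(w), x)): apply w to each column, then sum everything.
loss(w, x) = sum(reduce(hcat, [w * c for c in eachcol(x)]))

# Central finite differences over every entry of `a` (stand-in for ForwardDiff).
function fd_gradient(f, a; h = 1e-6)
    grad = similar(a)
    for i in eachindex(a)
        ap = copy(a); ap[i] += h
        am = copy(a); am[i] -= h
        grad[i] = (f(ap) - f(am)) / (2h)
    end
    return grad
end

w = ones(2, 2)
x = rand(2, 3)

# Closed-form gradients of sum(w * x).
dw_exact = ones(2, 3) * x'    # with respect to w, 2×2
dx_exact = w' * ones(2, 3)    # with respect to x, 2×3

dw_fd = fd_gradient(v -> loss(v, x), w)
dx_fd = fd_gradient(v -> loss(w, v), x)

isapprox(dw_fd, dw_exact; atol = 1e-6)   # true
isapprox(dx_fd, dx_exact; atol = 1e-6)   # true
```

Because the loss is linear in each argument, the central difference has no truncation error, so any mismatch beyond rounding would point at the gradient rule rather than at the check.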