Closed AzamatB closed 4 years ago
Yes, there's no real need to restrict f.
When f contains parameters, I'm not sure whether mapcols(f, M) will take the gradient with respect to f correctly. But for slicemap in particular I guess it should work. What's the simplest possible test of this?
I would like to use it in the forward pass of the BLSTM, so my use case is something like:
using Flux
using SliceMap
using Test
m = LSTM(3,2) |> gpu
xs = rand(Float32, 3,7,4) |> gpu
@test slicemap(m, xs; dims=2) == reduce(hcat, map(x -> reshape(m(x), 2,1,4), eachslice(xs; dims=2)))
Is this simple enough?
But that only checks the forward pass (and it depends on Flux). I meant a check that the gradient works, something like this perhaps, and compare to ForwardDiff?
using SliceMap, Zygote
struct F W end
(f::F)(x) = f.W * x
w = ones(2,2)
x = rand(2,3)
g = gradient(() -> sum(mapcols(F(w), x)), Params([w,x]))
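For the ForwardDiff comparison mentioned above, a reference gradient could be built roughly as follows. This is only a sketch: refcols is a hypothetical plain-Julia stand-in written for this example (it is not part of SliceMap), and the analytic values assume f.W * x with w = ones(2,2) as in the snippet above.

```julia
using ForwardDiff

struct F W end                 # same callable struct as in the snippet above
(f::F)(x) = f.W * x

# Hypothetical reference: apply f to each column and glue the results back
# together, without going through SliceMap at all.
refcols(f, M) = reduce(hcat, [f(col) for col in eachcol(M)])

w = ones(2, 2)
x = rand(2, 3)

# Reference gradients of sum(refcols(F(w), x)) with respect to w and to x.
ref_w = ForwardDiff.gradient(w_ -> sum(refcols(F(w_), x)), w)
ref_x = ForwardDiff.gradient(x_ -> sum(refcols(F(w), x_)), x)

# Sanity check against the hand-derived gradients of sum(W * x):
# d/dW[i,k] = sum_j x[k,j]  and  d/dx[k,j] = sum_i W[i,k].
analytic_w = ones(2, 1) * sum(x, dims=2)'
analytic_x = fill(2.0, 2, 3)
@assert ref_w ≈ analytic_w
@assert ref_x ≈ analytic_x
```

The entries that Zygote returns for w and x in g could then be compared against ref_w and ref_x with ≈. If the entry for w comes back as nothing, the gradient with respect to the parameters of f is not being tracked.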
In slicemap(f, A; dims), f is restricted to the Function type. This prevents passing Flux layers to it. Can we fix this?
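To illustrate why the annotation gets in the way: a callable struct (which is what Flux layers are) is not a subtype of Function, so a method that requires f::Function does not apply to it. A minimal sketch in plain Julia follows. The names restricted, unrestricted and Scale are made up for this example and are not SliceMap's or Flux's actual definitions.

```julia
# Hypothetical stand-ins for a column-mapping function with and without
# the ::Function annotation on f.
restricted(f::Function, A) = map(f, eachcol(A))
unrestricted(f, A) = map(f, eachcol(A))

# A callable struct, playing the role of a layer that carries parameters.
struct Scale
    w::Float64
end
(s::Scale)(x) = s.w .* x

A = [1.0 2.0; 3.0 4.0]
layer = Scale(2.0)

is_function = layer isa Function                     # false: Scale is not <: Function
restricted_ok = applicable(restricted, layer, A)     # false: no matching method
unrestricted_ok = applicable(unrestricted, layer, A) # true
cols = unrestricted(layer, A)                        # each column scaled by 2
```

Dropping the annotation (or widening it) lets any callable object through, and ordinary functions and closures keep working as before.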