Closed vitskvara closed 4 years ago
I am a little lost - currently the only working method I was able to come up with was to write methods like
half_split(X::AbstractArray{T,2}) where T = X[1:Int(size(X,1)/2),:], X[1+Int(size(X,1)/2):end,:]
half_split(X::AbstractArray{T,4}) where T = X[:,:,1:Int(size(X,3)/2),:], X[:,:,1+Int(size(X,3)/2):end,:]
...
This seems rather clumsy, but a general approach for arbitrary dimension, where the last-but-one dimension is used for splitting, seems to not be backpropagable
using GenerativeModels
using Flux
function half_split(X::AbstractArray)
# first get the dimension to be split
xlen = Int(size(X, ndims(X)-1)/2)
# now get the axes
axs1 = axs2 = [collect(ax) for ax in axes(X)]
axs1[end-1] = 1:xlen
axs2[end-1] = xlen+1:xlen*2
X[axs1...], X[axs2...]
end
l1 = Dense(4,4)
l2 = Dense(2,4)
x = randn(4,6)
function l(x)
a,b = half_split(l1(x))
Flux.mse(x, l2(a+b))
end
data = (x,)
opt = ADAM()
@time GenerativeModels.update_params!(l1, data, l, opt)
ERROR: MethodError: no method matching setindex!(::Tuple{Array{Int64,1},Array{Int64,1}}, ::UnitRange{Int64}, ::Int64)
Stacktrace:
[1] macro expansion at /home/vit/.julia/packages/Zygote/8dVxG/src/compiler/interface2.jl:0 [inlined]
[2] _pullback(::Zygote.Context, ::typeof(setindex!), ::Tuple{Array{Int64,1},Array{Int64,1}}, ::UnitRange{Int64}, ::Int64) at /home/vit/.julia/packages/Zygote/8dVxG/src/compiler/interface2.jl:7
[3] _pullback(::Zygote.Context, ::typeof(half_split), ::Array{Float32,2}) at ./REPL[13]:7
[4] l at ./REPL[17]:2 [inlined]
...
Closing this as it is irrelevant due to https://github.com/aicenter/ConditionalDists.jl/pull/5
Currently, the split only works for 2D outputs.