aicenter / ConditionalDists.jl

Conditional probability distributions powered by DistributionsAD.jl
MIT License

Change the mean var splitting for more than 2 dimensions #4

Closed by vitskvara 4 years ago

vitskvara commented 4 years ago

Currently, the mean/variance split only works for 2D (matrix) outputs. With a convolutional mapping that produces a 4D output, it fails:

using Flux
using ConditionalDists

xlen = (4, 4, 1) 
zlen = 2
batch = 10
T = Float32
x = randn(T, xlen..., batch)
p = CMeanVarGaussian{T,DiagVar}(f32(Conv((3,3), xlen[3]=>xlen[3]*2)))
y = mean(p, x)

ERROR: BoundsError: attempt to access 2×2×2×10 Array{Float32,4} at index [1:1, Base.Slice(Base.OneTo(2))]
Stacktrace:
 [1] throw_boundserror(::Array{Float32,4}, ::Tuple{UnitRange{Int64},Base.Slice{Base.OneTo{Int64}}}) at ./abstractarray.jl:538
 [2] checkbounds at ./abstractarray.jl:503 [inlined]
 [3] _getindex at ./multidimensional.jl:669 [inlined]
 ...
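
The index in the error message, [1:1, Base.Slice(Base.OneTo(2))], suggests that the split indexes the network output with two indices, as if it were a matrix. The following is a minimal sketch of that failure mode in plain Julia, without Flux or ConditionalDists. The array `A` and the variable names are illustrative stand-ins and are not taken from the package source.

```julia
# Stand-in for the Conv output in the report: width × height × channels × batch.
A = zeros(Float32, 2, 2, 2, 10)

# Two indices on a 4-D array are only in bounds when the omitted trailing
# dimensions have length 1, so this throws a BoundsError here.
threw = try
    A[1:1, :]
    false
catch err
    err isa BoundsError
end

# One index per dimension, slicing the channel (third) dimension instead.
first_half  = A[:, :, 1:1, :]
second_half = A[:, :, 2:2, :]

println(threw)             # whether the matrix-style index failed
println(size(first_half))  # shape of one half of the channel split
```

For a convolutional output the natural split is along the channel dimension, which is the second-to-last one in Flux's layout.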
vitskvara commented 4 years ago

I am a little lost. So far, the only working approach I have found is to write a separate method for each array dimensionality, like

half_split(X::AbstractArray{T,2}) where T = X[1:Int(size(X,1)/2),:], X[1+Int(size(X,1)/2):end,:]
half_split(X::AbstractArray{T,4}) where T = X[:,:,1:Int(size(X,3)/2),:], X[:,:,1+Int(size(X,3)/2):end,:]
...

This seems rather clumsy. However, a general approach for arbitrary dimensionality, which splits along the second-to-last dimension, does not seem to be differentiable with Zygote:

using GenerativeModels
using Flux

function half_split(X::AbstractArray)
    # first get the dimension to be split
    xlen = Int(size(X, ndims(X)-1)/2)

    # now get the axes
    axs1 = axs2 = [collect(ax) for ax in axes(X)]
    axs1[end-1] = 1:xlen
    axs2[end-1] = xlen+1:xlen*2

    X[axs1...], X[axs2...]
end

l1 = Dense(4,4)
l2 = Dense(2,4)

x = randn(4,6)
function l(x) 
    a,b = half_split(l1(x))
    Flux.mse(x, l2(a+b))
end

data = (x,)
opt = ADAM()
@time GenerativeModels.update_params!(l1, data, l, opt)

ERROR: MethodError: no method matching setindex!(::Tuple{Array{Int64,1},Array{Int64,1}}, ::UnitRange{Int64}, ::Int64)
Stacktrace:
 [1] macro expansion at /home/vit/.julia/packages/Zygote/8dVxG/src/compiler/interface2.jl:0 [inlined]
 [2] _pullback(::Zygote.Context, ::typeof(setindex!), ::Tuple{Array{Int64,1},Array{Int64,1}}, ::UnitRange{Int64}, ::Int64) at /home/vit/.julia/packages/Zygote/8dVxG/src/compiler/interface2.jl:7
 [3] _pullback(::Zygote.Context, ::typeof(half_split), ::Array{Float32,2}) at ./REPL[13]:7
 [4] l at ./REPL[17]:2 [inlined]
...
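
Two things stand out in the generic attempt above. The failing call is a `setindex!`, which is a mutation, and Zygote generally does not differentiate through mutation. Separately, `axs1 = axs2 = [...]` binds both names to the same vector, so even outside of AD both returned halves would be the second half. Below is a hedged sketch of a mutation-free variant that builds the index tuples with `ntuple`. It is an illustration only, not the change made in the linked pull request, and whether a given Zygote version differentiates it would still need to be checked.

```julia
# Hypothetical generic replacement for the per-dimensionality half_split methods.
# Splits X into two halves along its second-to-last dimension.
function half_split(X::AbstractArray{T,N}) where {T,N}
    N >= 2 || throw(ArgumentError("need at least 2 dimensions, got $N"))
    d = N - 1                 # dimension to split
    n = size(X, d)
    iseven(n) || throw(ArgumentError("size along dimension $d must be even, got $n"))
    h = n ÷ 2
    # Build both index tuples without mutation:
    # a Colon everywhere except along dimension d.
    idx1 = ntuple(i -> i == d ? (1:h) : Colon(), N)
    idx2 = ntuple(i -> i == d ? (h+1:n) : Colon(), N)
    return X[idx1...], X[idx2...]
end

# Matrix case: rows are split, matching the 2D method above.
X2 = reshape(collect(1.0:24.0), 4, 6)
a, b = half_split(X2)

# 4D case: channels are split, matching the 4D method above.
X4 = randn(Float32, 2, 2, 2, 10)
m, v = half_split(X4)
println(size(a), " ", size(m))
```

If the index construction itself were to trip the AD, one option would be to exclude it from differentiation, for example with `Zygote.ignore` in Zygote versions that provide it, since the indices carry no gradient.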
vitskvara commented 4 years ago

Closing this, as it is no longer relevant after https://github.com/aicenter/ConditionalDists.jl/pull/5