aicenter / ConditionalDists.jl

Conditional probability distributions powered by DistributionsAD.jl
MIT License

Scalar operations on GPU #17

Closed: nmheim closed this issue 4 years ago

nmheim commented 4 years ago

`CMeanVarGaussian` performs scalar operations on the GPU when it splits the mapping output into mean and variance, e.g. in `mean_var(p::CMeanVarGaussian, z)`. Is there a way to get rid of them? Do they even slow things down?

vitskvara commented 4 years ago

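The problem is actually in `CMeanGaussian`; `CMeanVarGaussian` seems to be fine. With scalar operations disallowed, computing gradients fails:

```julia
using Flux, ConditionalDists, GPUArrays
GPUArrays.allowscalar(false)  # turn scalar indexing of GPU arrays into an error

# p::CMeanGaussian and z live on the GPU, set up as in the benchmark further down
loss() = sum(mean(p,z) .+ variance(p,z))
ps = Flux.params(p)
gs = Flux.gradient(loss, ps)
```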

```
ERROR: scalar getindex is disallowed
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] assertscalar(::String) at /home/vit/.julia/packages/GPUArrays/1wgPO/src/indexing.jl:14
 [3] getindex(::CuArray{Float32,2,Nothing}, ::Int64) at /home/vit/.julia/packages/GPUArrays/1wgPO/src/indexing.jl:54
 [4] _getindex at ./abstractarray.jl:1004 [inlined]
 [5] getindex at ./abstractarray.jl:981 [inlined]
 [6] iterate at ./iterators.jl:237 [inlined]
 [7] #1101 at /home/vit/.julia/packages/Zygote/tJj2w/src/lib/array.jl:117 [inlined]
 [8] #2908#back at /home/vit/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:55 [inlined]
 [9] variance at /home/vit/.julia/dev/ConditionalDists/src/cmean_gaussian.jl:56 [inlined]
...
```
vitskvara commented 4 years ago

Avoiding scalar operations seems to speed things up by more than 4x (roughly 890 μs vs. 208 μs per gradient in the benchmark below), but we need an elegant implementation that works on both GPU and CPU.

```julia
using Flux, ConditionalDists, CuArrays, BenchmarkTools, GPUArrays

# Variance without `repeat`: tile the per-dimension variance across the batch
# via an outer product with a row of ones (GPU-only because of the CuArray call)
function nonscalar_variance(p::CMeanGaussian{DiagVar}, z::AbstractArray)
    T = eltype(p.σ)
    σ2 = p.σ .* p.σ .+ T(1e-8)
    σ2 * CuArray(ones(T,1,size(z,2)))
end

xlen  = 3
zlen  = 2
batch = 10
T     = Float32

mapping = Dense(zlen, xlen)
var = NoGradArray(ones(T, xlen))
p  = CMeanGaussian{DiagVar}(mapping, var) |> gpu
z  = randn(T, zlen, batch) |> gpu
ps = Flux.params(p)

loss() = sum(mean(p,z) .+ variance(p,z))                      # built-in variance
nonscalar_loss() = sum(mean(p,z) .+ nonscalar_variance(p,z))  # variant above
```

```
julia> @btime gs = Flux.gradient(loss, ps)
┌ Warning: Performing scalar operations on GPU arrays: This is very slow, consider disallowing these operations with `allowscalar(false)`
└ @ GPUArrays ~/.julia/packages/GPUArrays/1wgPO/src/indexing.jl:16
  890.547 μs (2198 allocations: 90.41 KiB)

julia> GPUArrays.allowscalar(false)

julia> @btime gs = Flux.gradient(nonscalar_loss, ps)
  208.069 μs (829 allocations: 30.92 KiB)
```
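The trick in `nonscalar_variance` replaces `repeat` with an outer product against a row of ones. A CPU-only sketch (plain `Array`s, no GPU needed) checking that both constructions give the same `xlen × batch` matrix; the variable names are illustrative only:

```julia
T     = Float32
xlen  = 3
batch = 10
σ     = rand(T, xlen)
σ2    = σ .* σ .+ T(1e-8)              # per-dimension variance, length xlen

via_repeat = repeat(σ2, 1, batch)      # tiles σ2 column-wise: xlen × batch
via_outer  = σ2 * ones(T, 1, batch)    # (xlen×1) * (1×batch) outer product
```

Each entry of `via_outer` is `σ2[i] * 1`, so the two matrices agree exactly; the outer product simply avoids the `repeat` adjoint that showed up in the stacktrace above.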
nmheim commented 4 years ago

So it is the `repeat` that's causing it?

vitskvara commented 4 years ago

Yes, it seems to be the cause.

nmheim commented 4 years ago

Hmm, then I guess we could use `fill!` to create the array of ones with the correct array type (CPU or GPU)?
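A minimal sketch of the `fill!` idea, assuming the variance only needs the standard-deviation vector `σ` and the batch `z`; `ones_row` and `agnostic_variance` are hypothetical names, not part of ConditionalDists.jl. `similar(z, T, 1, n)` allocates an array of the same kind as `z` (an `Array` for CPU input, a `CuArray` for GPU input), so no explicit `CuArray` call is needed. Only the CPU path is exercised here, and whether Zygote accepts the in-place `fill!` inside a gradient call is not checked; if it does not, the helper could be marked non-differentiable, e.g. with ChainRulesCore's `@non_differentiable`.

```julia
# Hypothetical helper: a 1×n row of ones whose array type follows `z`
# (Array for CPU input, CuArray for GPU input).
ones_row(z::AbstractMatrix, ::Type{T}) where {T} =
    fill!(similar(z, T, 1, size(z, 2)), one(T))

# Hypothetical device-agnostic variance: tile σ.^2 (plus jitter) across the
# batch dimension of z without `repeat` and without an explicit CuArray call.
function agnostic_variance(σ::AbstractVector, z::AbstractMatrix)
    T  = eltype(σ)
    σ2 = σ .* σ .+ T(1e-8)    # per-dimension variance, length(σ)
    σ2 * ones_row(z, T)       # (length(σ)×1) * (1×batch) outer product
end
```

For example, `agnostic_variance(Float32[1, 2, 3], randn(Float32, 2, 10))` returns a `3×10` `Float32` matrix whose columns all equal `σ .* σ .+ 1f-8`.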

vitskvara commented 4 years ago

Added https://github.com/aicenter/ConditionalDists.jl/pull/20