The problem is actually in `CMeanGaussian`; `CMeanVarGaussian` seems to be fine. Disallowing scalar operations breaks optimization:
```julia
using GPUArrays
GPUArrays.allowscalar(false)

loss() = sum(mean(p,z) .+ variance(p,z))
ps = params(p)
gs = Flux.gradient(loss, ps)
```
```
ERROR: scalar getindex is disallowed
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] assertscalar(::String) at /home/vit/.julia/packages/GPUArrays/1wgPO/src/indexing.jl:14
[3] getindex(::CuArray{Float32,2,Nothing}, ::Int64) at /home/vit/.julia/packages/GPUArrays/1wgPO/src/indexing.jl:54
[4] _getindex at ./abstractarray.jl:1004 [inlined]
[5] getindex at ./abstractarray.jl:981 [inlined]
[6] iterate at ./iterators.jl:237 [inlined]
[7] #1101 at /home/vit/.julia/packages/Zygote/tJj2w/src/lib/array.jl:117 [inlined]
[8] #2908#back at /home/vit/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:55 [inlined]
[9] variance at /home/vit/.julia/dev/ConditionalDists/src/cmean_gaussian.jl:56 [inlined]
...
```
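A minimal sketch of the pattern that seems to trigger this (assuming `variance` tiles the per-dimension variance across the batch with `repeat`, as the `iterate`/`getindex` frames in the stacktrace suggest; `repeat_variance` is just an illustrative name):

```julia
using ConditionalDists  # for CMeanGaussian and DiagVar

function repeat_variance(p::CMeanGaussian{DiagVar}, z::AbstractArray)
    T  = eltype(p.σ)
    σ2 = p.σ .* p.σ .+ T(1e-8)
    # Differentiating through `repeat` ends up iterating the CuArray element
    # by element, which is exactly the scalar `getindex` that gets disallowed.
    repeat(σ2, 1, size(z, 2))
end
```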
Non-scalar operations seem to speed things up by more than 4x, but we have to come up with an elegant implementation that is going to work on both GPU and CPU:
```julia
using Flux, ConditionalDists, CuArrays, BenchmarkTools, GPUArrays

function nonscalar_variance(p::CMeanGaussian{DiagVar}, z::AbstractArray)
    T = eltype(p.σ)
    σ2 = p.σ .* p.σ .+ T(1e-8)
    σ2 * CuArray(ones(T,1,size(z,2)))
end

xlen = 3
zlen = 2
batch = 10
T = Float32

mapping = Dense(zlen, xlen)
var = NoGradArray(ones(T, xlen))
p = CMeanGaussian{DiagVar}(mapping, var) |> gpu
z = randn(T, zlen, batch) |> gpu
ps = Flux.params(p)

loss() = sum(mean(p,z) .+ variance(p,z))
nonscalar_loss() = sum(mean(p,z) .+ nonscalar_variance(p,z))
```
```
julia> @btime gs = Flux.gradient(loss, ps)
┌ Warning: Performing scalar operations on GPU arrays: This is very slow, consider disallowing these operations with `allowscalar(false)`
└ @ GPUArrays ~/.julia/packages/GPUArrays/1wgPO/src/indexing.jl:16
  890.547 μs (2198 allocations: 90.41 KiB)

julia> GPUArrays.allowscalar(false)

julia> @btime gs = Flux.gradient(nonscalar_loss, ps)
  208.069 μs (829 allocations: 30.92 KiB)
```
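A device-agnostic variant could build the row of ones from `z` itself instead of calling the `CuArray` constructor explicitly. Just a sketch reusing the setup above (`device_agnostic_variance` is a made-up name, not part of the package):

```julia
function device_agnostic_variance(p::CMeanGaussian{DiagVar}, z::AbstractArray)
    T  = eltype(p.σ)
    σ2 = p.σ .* p.σ .+ T(1e-8)
    # one.(z[1:1, :]) is a 1×batch row of ones with the same array type as z
    # (CuArray on the GPU, Array on the CPU); range indexing and broadcasting
    # stay non-scalar, so this should work with allowscalar(false).
    σ2 * one.(z[1:1, :])
end
```

It can be dropped into the benchmark above via `() -> sum(mean(p,z) .+ device_agnostic_variance(p,z))`.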
So it is the `repeat` that's causing it?

Yes, it seems to be the cause.
Hmm, then we could try to use `fill!` in order to convert the array of ones to the correct type, I guess?
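Something along those lines might work, e.g. (a sketch; `ones_row` is a hypothetical helper, and the `@nograd` is needed because Zygote does not support differentiating through the mutation done by `fill!`):

```julia
using Zygote

# Hypothetical helper: a 1×n row of ones with the same array type (and device) as z.
ones_row(z::AbstractArray, n::Integer) = fill!(similar(z, 1, n), one(eltype(z)))
Zygote.@nograd ones_row  # the ones are constant, so no gradient needs to flow through fill!

function fill_variance(p::CMeanGaussian{DiagVar}, z::AbstractArray)
    T  = eltype(p.σ)
    σ2 = p.σ .* p.σ .+ T(1e-8)
    σ2 * ones_row(z, size(z, 2))
end
```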
We have scalar operations for `CMeanVarGaussian` on the GPU when splitting mean/variance, e.g. in `mean_var(p::CMeanVarGaussian, z)`. Is there a way of getting rid of this? Does it even slow things down?
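If the scalar indexing there comes from pulling out the mean/variance rows element by element, range indexing should be enough to avoid it, since slicing with ranges is a non-scalar `getindex` on CuArray. A sketch, assuming the mapping produces a single matrix whose first `xlen` rows are the mean and whose remaining rows parametrise the variance (names and layout are hypothetical):

```julia
# Split a (xlen + k) × batch output into mean and variance parts without
# touching individual elements.
function split_mean_var(out::AbstractMatrix, xlen::Int)
    μ = out[1:xlen, :]        # non-scalar range getindex, works on CuArray
    σ = out[xlen+1:end, :]
    return μ, σ
end
```

Whether it actually slows things down would probably need a benchmark like the one above.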