Open clintonTE opened 4 years ago
Confirmed not to be caused by scalar indexing (the script below runs with scalar indexing disabled):
using Flux, SparseDiffTools, BenchmarkTools, CuArrays, ForwardDiff, LinearAlgebra, Random
CuArrays.allowscalar(false)
N = 10
T = Float32
A = rand(T, N,N)
cuA = A |> gpu
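# In-place version, generic fallback (used on the CPU)
function f!(out, A)
    out .= A .+ A .* A .+ 1f0
end
# Scalar kernel broadcast over the GPU array
krn(x) = x + x*x + 1f0
function f!(out, A::CuMatrix{Float32})
    out .= krn.(A)
end
# Out-of-place version, generic fallback (used on the CPU)
function f(A)
    return A .+ A .* A .+ 1f0
end
# Out-of-place version specialized for the GPU
function f(A::CuMatrix{Float32})
    return krn.(A)
end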
J = rand(T, N^2, N^2)
@info "test cpu (inplace)"
cache = SparseDiffTools.ForwardColorJacCache(f!,A, dx = similar(A))
SparseDiffTools.forwarddiff_color_jacobian!(J, f!, A, cache)
(N<5) && @info "test ∇f cpu inplace: $(J)"
(N>5) && @btime SparseDiffTools.forwarddiff_color_jacobian!($J, $f!, $A, $cache)
@info "test cpu (out of place)"
cacheoos = SparseDiffTools.ForwardColorJacCache(f,A, dx = similar(A))
J = SparseDiffTools.forwarddiff_color_jacobian(f, A, cacheoos)
(N<5) && @info "test ∇f cpu oop: $(J)"
(N>5) && @btime SparseDiffTools.forwarddiff_color_jacobian($f, $A, $cacheoos)
@info "test gpu (inplace)"
cuJ = J |> gpu
cucache = SparseDiffTools.ForwardColorJacCache(f!,cuA, dx = similar(cuA))
SparseDiffTools.forwarddiff_color_jacobian!(cuJ, f!, cuA, cucache)
(N<5) && @info "test ∇f gpu inplace: $(cuJ)"
(N>5) && @btime SparseDiffTools.forwarddiff_color_jacobian!($cuJ, $f!, $cuA, $cucache)
@info "test gpu (outofplace)"
cucacheoop = SparseDiffTools.ForwardColorJacCache(f,cuA, dx = similar(cuA))
cuJ = SparseDiffTools.forwarddiff_color_jacobian(f, cuA, cucacheoop)
(N<5) && @info "test ∇f gpu oop: $(cuJ)"
(N>5) && @btime SparseDiffTools.forwarddiff_color_jacobian($f, $cuA, $cucacheoop)
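Because the test function is elementwise, its Jacobian should be diagonal with entries 1 + 2a for each element a of A. That gives a cheap way to check any of the Jacobians computed above for correctness before comparing timings. The snippet below is a minimal CPU-only sketch of that check. It is not part of the original report: it uses a plain dense `ForwardDiff.jacobian` call as the reference, and the names `Jref` and `Jexpected` are made up for illustration.

```julia
using ForwardDiff, LinearAlgebra

# Same elementwise function as the CPU method of f in the script above
f(A) = A .+ A .* A .+ 1f0

N = 4
A = rand(Float32, N, N)

# Dense reference Jacobian; rows and columns follow vec(f(A)) and vec(A)
Jref = ForwardDiff.jacobian(f, A)

# d/da (a + a*a + 1) = 1 + 2a, so the Jacobian is diagonal
Jexpected = Diagonal(1f0 .+ 2f0 .* vec(A))

@assert size(Jref) == (N^2, N^2)
@assert Jref ≈ Jexpected
```

The same comparison could presumably be applied to `J` on the CPU, or to `Array(cuJ)` after copying the GPU result back to the host.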
The problem is likely the fact that https://github.com/JuliaDiff/SparseDiffTools.jl/blob/v1.8.0/src/differentiation/compute_jacobian_ad.jl#L123-L126 produces too many kernels. It's somewhat tied up with https://github.com/JuliaDiff/SparseDiffTools.jl/pull/106 and the upstream issue https://github.com/JuliaGPU/CuArrays.jl/issues/571
https://github.com/JuliaDiff/SparseDiffTools.jl/pull/115 makes this much faster. I'll leave this open because it's not perfect (it now mutates), but the speed boost will essentially make this issue go away for most people.
Performance results are highly variable when using the cached out-of-place methods. I've also gotten segfaults, although I cannot reliably reproduce that part of the issue. Using CuArrays seems to amplify the problem.
Output: