baggepinnen opened this issue 4 years ago.
This method definition works around the problem, though I'm not sure whether it's a hack:
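```julia
using CuArrays, LinearAlgebra
import CuArrays: CuArray   # required to add methods to CuArray from outside the package

# Move the wrapped parent to the GPU first, rebuild the reshape around it,
# then materialize the result as a dense CuArray.
function CuArray(x::Base.ReshapedArray{<:Any, <:Any, <:Adjoint})
    xp = CuArray(x.parent)
    ra = Base.ReshapedArray(xp, x.dims, x.mi)
    CuArray(ra)
end
```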
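The three steps of the workaround can be checked on the CPU. This is a minimal sketch with `Matrix` standing in for `CuArray` (an assumption so it runs without a GPU; CuArrays itself is not exercised). Materializing the parent and re-wrapping it with the original `dims` and `mi` fields should reproduce the same values, because the parent's size does not change:

```julia
using LinearAlgebra

A = Float32[1 2 3; 4 5 6]        # dense 2×3 parent
x = reshape(A', 6)               # lazy reshape of a lazy adjoint

# 1. Materialize the wrapped parent (Matrix stands in for CuArray here).
xp = Matrix(parent(x))           # plain 3×2 Matrix
# 2. Rebuild the reshape around it, reusing the original dims and the
#    precomputed index multipliers (valid since size(xp) == size(parent(x))).
ra = Base.ReshapedArray(xp, x.dims, x.mi)
# 3. Materialize the result as a contiguous array.
y = collect(ra)

println(y)
```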
Another instance from Slack, where a reshape of a transpose doesn't broadcast properly. Workaround:
```julia
julia> GPUArrays.BroadcastStyle(::Type{<:Base.ReshapedArray{<:Any,<:Any,<:LinearAlgebra.Transpose{<:Any,AT},<:Any}}) where {AT<:GPUArray} = GPUArrays.BroadcastStyle(AT)

julia> GPUArrays.backend(::Type{<:Base.ReshapedArray{<:Any,<:Any,<:LinearAlgebra.Transpose{<:Any,AT},<:Any}}) where {AT<:GPUArray} = GPUArrays.backend(AT)

julia> vec(transpose(cu(rand(2,2)))) .+ 1
4-element CuArray{Float32,1}:
 1.974818
 1.2322885
 1.6275826
 1.213689
```
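The style-forwarding trick can be tried without a GPU. The sketch below uses a hypothetical `ToyArray` type with its own broadcast style as a stand-in for a GPU array, and defines the analogous forwarding method directly on `Base.BroadcastStyle` (not the GPUArrays method above). As far as I can tell, without such a method the outer `ReshapedArray` type falls back to Base's default array style, which would explain why the GPU backend is not picked up:

```julia
using LinearAlgebra

# Hypothetical toy array type standing in for a GPU array such as CuArray.
struct ToyArray{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(a::ToyArray) = size(a.data)
Base.getindex(a::ToyArray, i::Int...) = a.data[i...]

# Give the toy type its own broadcast style, as GPU array packages do.
Base.BroadcastStyle(::Type{<:ToyArray}) = Base.Broadcast.ArrayStyle{ToyArray}()

# vec(transpose(...)) nests the storage array inside two lazy wrappers:
# a ReshapedArray whose parent is a Transpose of the ToyArray.
x = vec(transpose(ToyArray(rand(2, 2))))

# Forward the style through ReshapedArray-of-Transpose to the storage type.
Base.BroadcastStyle(::Type{<:Base.ReshapedArray{<:Any,<:Any,<:Transpose{<:Any,AT}}}) where {AT<:ToyArray} =
    Base.BroadcastStyle(AT)

println(Base.BroadcastStyle(typeof(x)))
```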
If I understand correctly, we might need something like https://github.com/JuliaLang/julia/pull/31563 to handle this more fundamentally?
Would something like this make sense? It can be extended with other known wrapper types:
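```julia
using CuArrays, LinearAlgebra
import CuArrays: CuArray   # required to add methods to CuArray from outside the package

const Wrapper = Union{Base.ReshapedArray, LinearAlgebra.Transpose, LinearAlgebra.Adjoint}

# For each wrapper whose parent is itself a wrapper: move the parent to the GPU,
# rebuild the outer wrapper around it, then materialize a dense CuArray.
function CuArray(x::Base.ReshapedArray{<:Any, <:Any, <:Wrapper})
    xp = CuArray(x.parent)
    ra = Base.ReshapedArray(xp, x.dims, x.mi)
    CuArray(ra)
end

function CuArray(x::LinearAlgebra.Transpose{<:Any, <:Wrapper})
    xp = CuArray(x.parent)
    CuArray(transpose(xp))   # Transpose lives in LinearAlgebra, not Base
end

function CuArray(x::LinearAlgebra.Adjoint{<:Any, <:Wrapper})
    xp = CuArray(x.parent)
    CuArray(adjoint(xp))     # Adjoint lives in LinearAlgebra, not Base
end
```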
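One way to avoid writing a method per pair of wrappers is to recurse through the known wrappers and convert only the innermost storage array. The sketch below is a CPU-only illustration under stated assumptions: `rewrap` and `convert_storage` are hypothetical helpers, and `copy` stands in for `CuArray`. On the GPU the idea would be something like `CuArray(rewrap(CuArray, x))`, which I have not tested:

```julia
using LinearAlgebra

# Hypothetical helper: apply `f` (e.g. CuArray) to the innermost storage array,
# then rebuild the known lazy wrappers around the result, innermost first.
rewrap(f, x::AbstractArray) = f(x)                     # plain storage: convert it
rewrap(f, x::Transpose) = transpose(rewrap(f, parent(x)))
rewrap(f, x::Adjoint) = adjoint(rewrap(f, parent(x)))
rewrap(f, x::Base.ReshapedArray) = reshape(rewrap(f, parent(x)), size(x))

A = rand(Float32, 2, 3)
x = reshape(transpose(A), 6)     # ReshapedArray whose parent is a Transpose

# Record what `f` receives, to show it only ever sees the storage array.
seen = Type[]
convert_storage(a) = (push!(seen, typeof(a)); copy(a))  # stand-in for CuArray

y = rewrap(convert_storage, x)
println(seen)
```

Supporting a new wrapper then needs only one `rewrap` method, and the final dense conversion stays a single call on the rebuilt result.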
Details on Julia:
I hit this error constantly when backpropagating with Tracker, though I'm not sure which adjoint definition causes it.
Both the adjoint and the reshape are required for the error to appear; either one on its own works fine.
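A quick CPU check (plain `Array`s, no GPU) of which combinations produce a nested wrapper. Reshaping a dense `Array` returns another `Array`, and `A'` gives a single `Adjoint`; only the combination yields a `ReshapedArray` whose parent is an `Adjoint`, which is presumably the case the conversion does not handle:

```julia
using LinearAlgebra

A = rand(Float32, 2, 3)

r  = reshape(A, 3, 2)   # reshape alone: an Array sharing A's memory, no wrapper
a  = A'                 # adjoint alone: a single Adjoint wrapper around A
ra = reshape(A', 6)     # both: a ReshapedArray whose parent is an Adjoint

println(typeof(r))
println(typeof(a))
println(typeof(ra))
```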