Closed maleadt closed 4 years ago
There are two quite different cases that come up here: (1), you want to move all the arrays in a collection / data structure to the GPU, and (2) you want to move the "whole object" to the GPU (as in your example).
(1) is likely what the users in that thread want, since you typically create a series of (flat) GPU arrays as data points in Flux and iterate over them on the CPU (i.e. `Vector{CuVector{Float32}}`). That's also what we provide `mapleaves(cu, x)` for, so that seems like it's covered.
I think `cu` should stick to the "single object" interpretation, (2), which in this case means producing the nested `CuArray` as you describe or otherwise erroring out. My main reservation is that it might be confusing or slow for people who actually want (1) (which IME so far is 100% of the time).
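To make the two interpretations concrete, here is a minimal CPU-only sketch. `DeviceArray` and `to_device` are hypothetical stand-ins for `CuArray` and `cu`, so it runs without a GPU; it only illustrates the shape of the result and says nothing about how CuArrays actually stores data.

```julia
# Hypothetical stand-in for a device array (NOT the real CuArray type).
struct DeviceArray{T,N}
    data::Array{T,N}
end

# Hypothetical converter playing the role of `cu` for flat numeric arrays.
to_device(x::Array{<:Number}) = DeviceArray(x)

batches = [Float32[1, 2], Float32[3, 4, 5]]

# (1) CPU container of device arrays: convert each element, keep the outer Vector.
per_element = map(to_device, batches)

# (2) "Whole object": the outer container is itself wrapped as a device array.
whole_object = DeviceArray(per_element)
```

In this sketch `per_element` stays an ordinary `Vector` you can iterate over on the CPU, while `whole_object` is nested all the way down.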
Tuples are a bit weird, but I'd argue that `Tuple{CuArray}` can be interpreted both as (1) "cpu tuple of gpu arrays" and (2) "gpu tuple of arrays", which makes it legitimate that both `cu` and `mapleaves(cu)` do the same thing in that case. In other cases, like arrays of arrays, they won't behave the same.
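A rough sketch of why tuples and arrays of arrays diverge, again CPU-only. `leafmap` is a hypothetical walker written for this example, not Flux's `mapleaves`, and `DeviceArray` / `to_device` are hypothetical stand-ins for `CuArray` / `cu`.

```julia
# Hypothetical stand-in for a device array (NOT the real CuArray type).
struct DeviceArray{T,N}
    data::Array{T,N}
end

# Hypothetical "single object" converter: only accepts flat numeric arrays.
to_device(x::Array{<:Number}) = DeviceArray(x)

# Hypothetical leaf walker: recurse into tuples and arrays of arrays,
# apply `f` to flat numeric arrays, and pass anything else through unchanged.
leafmap(f, x::Array{<:Number}) = f(x)
leafmap(f, x::Tuple) = map(y -> leafmap(f, y), x)
leafmap(f, x::Array) = map(y -> leafmap(f, y), x)
leafmap(f, x) = x

# Tuple case: converting each field is the natural reading either way.
on_tuple = leafmap(to_device, ([1],))

# Array-of-arrays case: the walker yields a CPU Vector of device arrays,
# whereas the single-object converter has no method for the nested input.
nested = [[1], [2, 3]]
on_nested = leafmap(to_device, nested)
single_object_ok = applicable(to_device, nested)
```

Here `single_object_ok` is `false`, which corresponds to the "otherwise erroring out" branch of interpretation (2).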
Right, I didn't realize Flux/users wanted `Array{CuArray}`, so I was only considering the case where you'd get a `CuArray{CuArray}`. In that case, the existing `mapleaves` abstraction seems fine. Having `cu` do that automatically seems confusing to me, because e.g. the behavior of adapting a container would then depend on the adaptability of its elements. I'll open an issue on CuArrays to improve the error message.
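As an illustration only of what a more helpful failure could look like, here is a CPU-only sketch. `strict_to_device` and `DeviceArray` are hypothetical names, the message text is made up, and the `isbitstype` check is an assumption about what a device array can hold rather than the actual CuArrays logic.

```julia
# Hypothetical stand-in for a device array (NOT the real CuArray type).
struct DeviceArray{T,N}
    data::Array{T,N}
end

# Hypothetical strict converter: refuses containers whose elements are not
# plain bits types (an assumption for this sketch) and suggests the
# per-element alternative instead of failing with an opaque error.
function strict_to_device(x::Array{T,N}) where {T,N}
    if !isbitstype(T)
        throw(ArgumentError(
            "cannot move an Array with element type $T as a single object; " *
            "to convert each element instead, use map(strict_to_device, x)"))
    end
    return DeviceArray(x)
end

flat = strict_to_device(Float32[1, 2])

# Capture the error raised for a nested input instead of letting it propagate.
nested_error = try
    strict_to_device([Float32[1], Float32[2]])
    nothing
catch err
    err
end
```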
Users seem to expect Adapt to work recursively, for example to adapt arrays of arrays (see https://discourse.julialang.org/t/using-vectors-of-vectors-with-cuarrays/).
Since we already treat tuples recursively, e.g. `cu(([1],))` yields a tuple of `CuArray`, maybe we should also peek into arrays? It gets hairy real quickly though, so maybe just support arrays of arrays? Thoughts? @vchuravy @MikeInnes