nathanaelbosch opened this issue 2 years ago (status: open).
As I mentioned in https://github.com/JuliaGPU/CUDA.jl/issues/1482, the problem is more fundamental than `scan!` not supporting some containers: we generally don't support nested `CuArray`s, because it is not obvious how to represent them on the device. Creating a `CuArray` allocates device memory to represent the elements, so a `CuArray{<:CuArray}` would allocate memory to represent `CuArray` elements on the device. That cannot work, because `CuArray` is a host-side abstraction that has no meaning on the device. Working with `CuArray{<:CuDeviceArray}` objects already works, but then the elements aren't usable on the host.
So what I meant in https://github.com/JuliaGPU/CUDA.jl/issues/1482#issuecomment-1112119930 is that we first need to figure these questions out, e.g. by looking at what other frameworks do, before improving `scan!`. Or do other frameworks not support this either, and is `associative_scan` a special abstraction to handle this case?
I think `jax.lax.associative_scan` does not support nested JAX arrays either; in the example above, I passed `(As, Bs)` to the scan, i.e. a Python tuple of `jaxlib.xla_extension.DeviceArray`s.
The main difference from `CUDA.scan!` seems to be that `jax.lax.associative_scan` accepts not one but many arrays. The actual scan algorithm in JAX is implemented here, as `scan(elems)`, where `elems` is a Python list of JAX arrays. Most of the surrounding code only exists to handle not just lists of arrays but more structured "pytrees" of arrays, by flattening and unflattening them as necessary.
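For comparison, the recursive structure of JAX's `scan(elems)`, as I read it, can be sketched in Julia for a tuple of plain CPU arrays: combine adjacent pairs, scan the half-length problem recursively, then fill in the remaining positions. This is an illustrative sketch, not CUDA.jl API. It assumes `op` takes two tuples of equally shaped slices and returns a tuple of arrays with the same shapes and element types. `slice_each`, `tuple_assoc_scan` and `affine_op` are hypothetical names.

```julia
# CPU sketch of a JAX-style recursive associative scan over a tuple of arrays.
# Every array is scanned along the same dimension `dims`. Assumes `op` maps two
# tuples of equally sized slices to one tuple of arrays with the same shapes and
# element types. All names here are hypothetical, not CUDA.jl API.

# Take the same index range along dimension `d` from every array in a tuple.
slice_each(t::Tuple, d::Int, idx) = map(a -> selectdim(a, d, idx), t)

function tuple_assoc_scan(op, elems::Tuple; dims::Int)
    n = size(first(elems), dims)
    n < 2 && return map(copy, elems)
    # 1. Combine adjacent pairs (1,2), (3,4), ... along `dims`.
    reduced = op(slice_each(elems, dims, 1:2:n-1), slice_each(elems, dims, 2:2:n))
    # 2. Recursively scan the half-length problem: prefixes at positions 2, 4, ...
    even_pos = tuple_assoc_scan(op, reduced; dims = dims)
    # 3. Prefixes at positions 3, 5, ...: previous even prefix combined with one element.
    rest = 3:2:n
    odd_pos = isempty(rest) ? nothing :
        op(slice_each(even_pos, dims, 1:length(rest)), slice_each(elems, dims, rest))
    # 4. Interleave the results; position 1 is just the first element.
    out = map(similar, elems)
    for k in eachindex(out)
        selectdim(out[k], dims, 1:1) .= selectdim(elems[k], dims, 1:1)
        selectdim(out[k], dims, 2:2:n) .= even_pos[k]
        if odd_pos !== nothing
            selectdim(out[k], dims, rest) .= odd_pos[k]
        end
    end
    return out
end

# Example operator: compose elementwise affine maps x -> a*x + b.
# Applying (a1, b1) first and then (a2, b2) gives (a2*a1, a2*b1 + b2).
affine_op((a1, b1), (a2, b2)) = (a2 .* a1, a2 .* b1 .+ b2)

a, b = rand(1001, 4), rand(1001, 4)
A, B = tuple_assoc_scan(affine_op, (a, b); dims = 1)
```

Like the JAX version, this applies `op` O(n) times in total, and each call processes a whole batch of slices at once, which is what would map onto GPU kernels.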
So the question for CUDA.jl seems to be: could we adjust or extend `scan!` to handle lists or tuples of `CuArray`s?
EDIT: To prevent confusion about "lists or tuples of CuArrays": I mean support for, for example, a tuple of an `N×A×B` tensor and an `N×C` tensor, such that the scan iterates along a specified axis of each `CuArray` (here the first, with `N` elements), not along the outer list or tuple.
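To pin down these semantics, here is a sequential CPU reference, assuming a concrete operator chosen only for illustration: each element is an affine map `x -> M*x + v`, stored as a tuple of an `N×d×d` array `Ms` and an `N×d` array `vs` (so `A = B = C = d`). The scan runs along dimension 1 of both arrays, not over the tuple. `affine_compose` and `sequential_tuple_scan` are hypothetical names, not part of CUDA.jl or Base.

```julia
# Sequential CPU reference for scanning a tuple (Ms, vs) along dimension 1,
# with Ms of size N×d×d and vs of size N×d. Names are hypothetical.

# One element is the affine map x -> M*x + v. Applying (M1, v1) first and
# then (M2, v2) gives x -> M2*(M1*x + v1) + v2 = (M2*M1)*x + (M2*v1 + v2).
affine_compose((M1, v1), (M2, v2)) = (M2 * M1, M2 * v1 + v2)

function sequential_tuple_scan(op, (Ms, vs))
    outM, outv = copy(Ms), copy(vs)
    for i in 2:size(Ms, 1)
        # Combine the running prefix (entry i-1 of each output) with element i.
        M, v = op((outM[i-1, :, :], outv[i-1, :]), (Ms[i, :, :], vs[i, :]))
        outM[i, :, :] .= M
        outv[i, :] .= v
    end
    return outM, outv
end

N, d = 50, 3
Ms, vs = randn(N, d, d) ./ 3, randn(N, d)
PM, Pv = sequential_tuple_scan(affine_compose, (Ms, vs))
```

Entry `i` of the result is the composition of the first `i` maps, so applying it to a starting point `x0` should agree with iterating the maps one by one.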
Ah yes, that seems like a viable and valuable addition to `CUDA.scan!` then (and presumably also to `Base.accumulate!`?). That said, I won't have time to look at this anytime soon, so feel free to take a stab at it yourself. Happy to help with any questions you (or anybody else) might have!
I ran into a similar issue here: state space models, which require this algorithm, have recently been taking over sequence modeling in deep learning. I sketched out a work-inefficient version of the parallel scan that operates on an `l × l × h` array to compute the cumulative matrix product along `h`:
```julia
using CUDA
using LinearAlgebra
using NNlib            # provides batched_mul
using BenchmarkTools

CUDA.allowscalar(false)

# 3×3×50000 batch of random matrices, each scaled to unit Frobenius norm.
# Vectorized instead of looping over slices, which launches one kernel per slice.
function create_normalized_data()
    x = CUDA.rand(3, 3, 50000)
    return x ./ sqrt.(sum(abs2, x; dims = (1, 2)))
end

# Associative operator: batched matrix product over the third dimension.
operator(a, b) = batched_mul(a, b)

# Sequential reference: y[:, :, i] = y[:, :, i-1] * x[:, :, i].
# `operator` is accepted for a uniform call signature; the loop uses `*` on slices.
function sequential_all_prefix_sum(operator, x)
    n = size(x, 3)
    y = copy(x)
    for i in 2:n
        y[:, :, i] = y[:, :, i-1] * x[:, :, i]
    end
    return y
end

# Work-inefficient (Hillis-Steele style) scan: at depth d, every element after
# position 2^(d-1) is combined with the element 2^(d-1) positions before it.
function recursive_parallel_scan(operator, x::AbstractArray, depth::Int, max_depth::Int)
    depth > max_depth && return x
    n = size(x, 3)
    stride = 2^(depth - 1)
    r_ = operator(view(x, :, :, 1:n-stride), view(x, :, :, stride+1:n))
    x = cat(view(x, :, :, 1:stride), r_; dims = 3)
    return recursive_parallel_scan(operator, x, depth + 1, max_depth)
end

function recursive_parallel_scan(operator, x::AbstractArray)
    j_max = ceil(Int, log2(size(x, 3)))   # number of doubling steps
    return recursive_parallel_scan(operator, x, 1, j_max)
end

# Example
x = create_normalized_data();
ans1 = recursive_parallel_scan(operator, x);
ans2 = sequential_all_prefix_sum(operator, x);
isapprox(ans1, ans2)  # true

@benchmark CUDA.@sync recursive_parallel_scan($operator, $x)
@benchmark CUDA.@sync sequential_all_prefix_sum($operator, $x)
```
`recursive_parallel_scan`:

```
BenchmarkTools.Trial: 1126 samples with 1 evaluation.
 Range (min … max):  3.715 ms … 55.529 ms  ┊ GC (min … max): 0.00% … 65.94%
 Time  (median):     4.086 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   4.427 ms ±  2.893 ms  ┊ GC (mean ± σ):  2.87% ±  4.01%
  █▄▃▅▇▆▅▄▅▃▄▄▃▁
  ███████████████▇▄▄▆▆▄▄▇▆▆▄▆▅▆▁▄▅▆▄▁▅▅▄▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄ █
  3.72 ms      Histogram: log(frequency) by time     8.14 ms <
 Memory estimate: 123.97 KiB, allocs estimate: 2910.
```

`sequential_all_prefix_sum`:

```
BenchmarkTools.Trial: 2 samples with 1 evaluation.
 Range (min … max):  3.247 s …  3.293 s  ┊ GC (min … max): 3.42% … 2.90%
 Time  (median):     3.270 s             ┊ GC (median):    3.16%
 Time  (mean ± σ):   3.270 s ± 32.233 ms ┊ GC (mean ± σ):  3.16% ± 0.37%
  █                                                        █
  █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
  3.25 s         Histogram: frequency by time        3.29 s <
 Memory estimate: 381.92 MiB, allocs estimate: 9501981.
```
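Applied to the code above, what this issue asks for would mean running the same doubling loop over several arrays at once. A CPU-only sketch, assuming `op` takes and returns tuples of batched arrays; `tuple_doubling_scan` and `batched_affine` are hypothetical names, and the loop over `j` in `batched_affine` stands in for `batched_mul` so the sketch runs without NNlib or a GPU:

```julia
# The doubling loop of `recursive_parallel_scan`, generalized from a single
# array to a tuple of arrays scanned along a shared dimension `dims`.
# CPU arrays only; names are hypothetical, not CUDA.jl API.
function tuple_doubling_scan(op, xs::Tuple; dims::Int)
    n = size(first(xs), dims)
    stride = 1
    while stride < n
        # x[j] <- op(x[j - stride], x[j]) for j > stride; the first `stride` entries stay.
        left  = map(a -> selectdim(a, dims, 1:n-stride), xs)
        right = map(a -> selectdim(a, dims, stride+1:n), xs)
        r = op(left, right)
        xs = map((a, c) -> cat(selectdim(a, dims, 1:stride), c; dims = dims), xs, r)
        stride *= 2
    end
    return xs
end

# Element j is the affine map x -> M[:, :, j]*x + v[:, j], with the batch on the
# last dimension as in the 3×3×h example. Applies (M1, v1) first, then (M2, v2).
function batched_affine((M1, v1), (M2, v2))
    M, v = similar(M1), similar(v1)
    for j in axes(M1, 3)
        M[:, :, j] = M2[:, :, j] * M1[:, :, j]
        v[:, j] = M2[:, :, j] * v1[:, j] + v2[:, j]
    end
    return M, v
end

h = 37
M0, v0 = randn(3, 3, h) ./ 3, randn(3, h)
SM, Sv = tuple_doubling_scan(batched_affine, (M0, v0); dims = 3)
```

As in the single-array version, this performs O(h log h) operator applications instead of the O(h) of a sequential loop. That is what makes it work-inefficient, but each step is one batched call.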
I am trying to implement time-parallel Kalman filters, which can be formulated as a prefix sum [1]. But the associative operation is a bit more complicated, so `op(::Array, ::Array)` is not really sufficient. Most people in the field currently use JAX and its `associative_scan` function, which operates on "pytrees", typically a tuple of arrays. The scan is then run along a specified axis of these arrays. Here is a simple example of how associative scans can be used in JAX:
Is there a good way to replicate this functionality in Julia with CUDA.jl? Or is there a possibility to extend `CUDA.scan!` to enable more complicated inputs, such as tuples of arrays?

Related: #1482 contains an attempt at running `scan!` on a `CuArray` of structs, as well as some discussion on the topic.

[1] Temporal Parallelization of Bayesian Smoothers, Simo Särkkä, Angel F. Garcia-Fernandez (https://arxiv.org/abs/1905.13002)
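To make the "more complicated" operator concrete: as I understand the filtering part of [1], each element is a tuple `(A, b, C, η, J)` of matrices and vectors, and two elements combine as sketched below. This is a transcription for illustration, not code from this issue or from any library; `filtering_op` is a hypothetical name, and the formulas should be checked against the paper before relying on them. It shows why a scan over a single array is not enough: one scan element is a tuple of five arrays of different shapes.

```julia
using LinearAlgebra

# Sketch of the filtering combination operator described in [1]: an element is
# a tuple (A, b, C, η, J), with C and J symmetric positive semidefinite.
# Hypothetical transcription for illustration; verify against the paper.
function filtering_op((A1, b1, C1, η1, J1), (A2, b2, C2, η2, J2))
    T1 = A2 / (I + C1 * J2)            # A2 (I + C1 J2)⁻¹
    A = T1 * A1
    b = T1 * (b1 + C1 * η2) + b2
    C = T1 * C1 * A2' + C2
    T2 = A1' / (I + J2 * C1)           # A1ᵀ (I + J2 C1)⁻¹
    η = T2 * (η2 - J2 * b1) + η1
    J = T2 * J2 * A1 + J1
    return (A, b, C, η, J)
end

# Random element with symmetric PSD C and J, e.g. for checking associativity.
function random_element(d)
    R, S = randn(d, d), randn(d, d)
    return (randn(d, d), randn(d), R * R', randn(d), S * S')
end
```

A batched GPU version would store each of the five components as an array with an extra time dimension and scan all five along it together, which is exactly the tuple-of-arrays case.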