JuliaGPU / CUDA.jl

CUDA programming in Julia.
https://juliagpu.org/cuda/

Support for more complex associative scans #1492

Open nathanaelbosch opened 2 years ago

nathanaelbosch commented 2 years ago

I am trying to implement time-parallel Kalman filters, which can be formulated as a prefix sum [1]. But the associative operation is a bit more complicated, so op(::Array, ::Array) is not really sufficient. Most people in the field currently use jax and its associative_scan function, which operates on "pytrees", typically a tuple of arrays. The scan is then run along a specified axis of these arrays.

Here is a simple example on how associative scans can be used in Jax:

import jax
import jax.numpy as jnp

def double_matmul(elem1: tuple, elem2: tuple):
    Ai, Bi = elem1
    Aj, Bj = elem2
    return Ai @ Aj, Bi @ Bj

def get_elem(dummy_arg=None, D=2):
    # the dummy arg is just there to easily work with vmap
    A = 0.9 * jnp.eye(D)
    B = 1.1 * jnp.eye(D)
    return A, B

# The input for the associative scan will be a tuple of arrays:
As, Bs = jax.vmap(get_elem)(jnp.ones(100))
# As.shape == Bs.shape == (100, 2, 2)

Aout, Bout = jax.lax.associative_scan(double_matmul, (As, Bs))
# Aout.shape == Bout.shape == (100, 2, 2)

Is there a good way to replicate this functionality in Julia with CUDA.jl? Or, is there a possibility to extend CUDA.scan! to enable more complicated inputs, such as tuples of Arrays?
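
For reference, this is the computation I am after, written out sequentially in Julia, just to pin down the semantics (not meant as an efficient implementation; double_matmul matches the Jax snippet above, sequential_scan is only a helper for this example, and the scan runs along the first axis):

using CUDA, LinearAlgebra

function double_matmul(elem1, elem2)
    (Ai, Bi), (Aj, Bj) = elem1, elem2
    return Ai * Aj, Bi * Bj
end

# Same inputs as in the Jax example: 100 copies of 0.9*I and 1.1*I,
# with the scan axis first.
As = cu(repeat(reshape(0.9f0 * Matrix{Float32}(I, 2, 2), 1, 2, 2), 100, 1, 1))
Bs = cu(repeat(reshape(1.1f0 * Matrix{Float32}(I, 2, 2), 1, 2, 2), 100, 1, 1))

function sequential_scan(op, As, Bs)
    Aout, Bout = copy(As), copy(Bs)
    for i in 2:size(As, 1)
        Aout[i, :, :], Bout[i, :, :] =
            op((Aout[i-1, :, :], Bout[i-1, :, :]), (As[i, :, :], Bs[i, :, :]))
    end
    return Aout, Bout
end

Aout, Bout = sequential_scan(double_matmul, As, Bs)
# size(Aout) == size(Bout) == (100, 2, 2)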

Related: #1482 contains an attempt for running scan! on a CuArray of structs, as well as some discussion on the topic.

[1] Temporal Parallelization of Bayesian Smoothers, Simo Särkkä and Ángel F. García-Fernández (https://arxiv.org/abs/1905.13002)

maleadt commented 2 years ago

As I mentioned in https://github.com/JuliaGPU/CUDA.jl/issues/1482, the problem is more fundamental than scan! not supporting some containers: we generally don't support nested CuArrays, because it is not obvious how to represent them on the device. Creating a CuArray allocates device memory for the elements, so a CuArray{<:CuArray} would have to allocate memory to represent CuArray elements on the device, but that cannot work because CuArray is a host-side abstraction that holds no meaning on the device. CuArray{<:CuDeviceArray} objects already work, but then the elements aren't usable on the host.
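
To illustrate the difference (StaticArrays is only used for this example): isbits elements, such as small static matrices, are stored inline in device memory and can be used from kernels, whereas CuArray elements would only be host-side handles:

using CUDA, StaticArrays

# isbits elements are stored inline on the device and usable from kernels:
xs = CuArray([@SMatrix(rand(2, 2)) for _ in 1:10])
ys = map(A -> A * A, xs)    # runs on the GPU, one 2x2 matmul per element

# CuArray elements, on the other hand, are host-side objects with no meaning
# inside a kernel, so a nested CuArray cannot be represented on the device:
# zs = CuArray([CUDA.rand(2, 2) for _ in 1:10])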

So what I meant with https://github.com/JuliaGPU/CUDA.jl/issues/1482#issuecomment-1112119930 is that we first need to figure these questions out, e.g. by looking at what other frameworks do, before improving scan!. Or do other frameworks not support this either, and is associative_scan a special abstraction to handle this case?

nathanaelbosch commented 2 years ago

I think jax.lax.associative_scan also does not support nested jax arrays; in the example above, I passed (As, Bs) to the scan, i.e. a Python tuple of jaxlib.xla_extension.DeviceArrays.

The main difference to CUDA.scan! seems to be that jax.lax.associative_scan allows passing not one but many arrays. The actual scan algorithm in jax is implemented here, as scan(elems), where elems is now a Python list of jax arrays. Most of the surrounding code is just there to handle not only lists of arrays but more structured "pytrees" of arrays, by flattening and unflattening as necessary.

So the question for CUDA.jl seems to be: Could we adjust or extend scan! to handle lists or tuples of CuArrays?

EDIT: To prevent confusion regarding the "lists or tuples of CuArrays": I mean support for, for example, a tuple of an NxAxB tensor and an NxC tensor, such that the iteration runs along a specified axis in each of the CuArrays (here the first, with N elements), not along the outermost list or tuple.
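
To make that concrete, the call I have in mind would look roughly like this (purely hypothetical; neither the tuple method of accumulate nor the dims handling shown here exist today):

using CUDA

N, A, B, C = 100, 2, 3, 4
xs = (CUDA.rand(N, A, B), CUDA.rand(N, C))

# placeholder associative operation: it combines two tuples of matching slices
# (an AxB matrix and a length-C vector each) into a tuple of the same shapes
op((Xi, yi), (Xj, yj)) = (Xi .+ Xj, yi .+ yj)

# desired: scan along the first axis of every array in the tuple, returning
# again a tuple of an NxAxB and an NxC CuArray
# ys = accumulate(op, xs; dims=1)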

maleadt commented 2 years ago

Ah yes, that seems like a viable & valuable addition to CUDA.scan! then (and presumably also Base.accumulate!?). That said, I won't have the time to look at this anytime soon, so feel free to take a stab at this yourself. Happy to help with any questions you (or anybody else) would have!

AnasAbdelR commented 9 months ago

I ran into a similar issue, as the recent buzz around state space models in deep-learning sequence modeling calls for this algorithm. I sketched out something that demonstrates the work-inefficient version of the parallel scan, operating on an l x l x h array to compute the cumulative matrix product along h:

using CUDA
using LinearAlgebra
using NNlib
using BenchmarkTools

CUDA.allowscalar(false)

function create_normalized_data()
    x = CUDA.rand(3, 3, 50000)  
    for i in 1:size(x, 3)
        x[:, :, i] /= norm(x[:, :, i])
    end
    return x
end

function operator(a, b)
    return batched_mul(a,b)
end

function sequential_all_prefix_sum(operator, x)
    n = size(x, 3)
    y = copy(x) 
    for i in 2:n
        y[:, :, i] = y[:, :, i-1]*x[:, :, i]
    end
    return y
end

function recursive_parallel_scan(operator, x::AbstractArray, depth::Int, max_depth::Int)
    depth > max_depth && return x
    n = size(x, 3)
    stride = 2^(depth - 1)
    r_ = operator(view(x, :, :, 1:n-stride), view(x, :, :, stride+1:n))
    x = cat(view(x, :, :, 1:stride),  r_; dims = 3)
    return recursive_parallel_scan(operator, x, depth + 1, max_depth)
end

function recursive_parallel_scan(operator, x::AbstractArray)
    j_max = ceil(Int, log2(size(x, 3)))
    return recursive_parallel_scan(operator, x, 1, j_max)
end

# Example
x = create_normalized_data(); 
ans1 = recursive_parallel_scan(operator, x);
ans2 = sequential_all_prefix_sum(operator, x);
isapprox(ans1,ans2) # true

@benchmark CUDA.@sync recursive_parallel_scan($operator, $x)
@benchmark CUDA.@sync sequential_all_prefix_sum($operator, $x)

# recursive_parallel_scan
BenchmarkTools.Trial: 1126 samples with 1 evaluation.
 Range (min … max):  3.715 ms … 55.529 ms  ┊ GC (min … max): 0.00% … 65.94%
 Time  (median):     4.086 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   4.427 ms ±  2.893 ms  ┊ GC (mean ± σ):  2.87% ±  4.01%

  █▄▃▅▇▆▅▄▅▃▄▄▃▁                                              
  ███████████████▇▄▄▆▆▄▄▇▆▆▄▆▅▆▁▄▅▆▄▁▅▅▄▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄ █
  3.72 ms      Histogram: log(frequency) by time     8.14 ms <

 Memory estimate: 123.97 KiB, allocs estimate: 2910.

# sequential_all_prefix_sum
BenchmarkTools.Trial: 2 samples with 1 evaluation.
 Range (min … max):  3.247 s …   3.293 s  ┊ GC (min … max): 3.42% … 2.90%
 Time  (median):     3.270 s              ┊ GC (median):    3.16%
 Time  (mean ± σ):   3.270 s ± 32.233 ms  ┊ GC (mean ± σ):  3.16% ± 0.37%

  █                                                       █  
  █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
  3.25 s         Histogram: frequency by time        3.29 s <

 Memory estimate: 381.92 MiB, allocs estimate: 9501981.
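
The same recursion also extends fairly naturally to a tuple of arrays, which is essentially the interface asked for above. A rough sketch reusing the definitions from the snippet (tuple_parallel_scan and double_batched_matmul are made-up names, not CUDA.jl API; the scan still runs along the third dimension):

function tuple_parallel_scan(op, xs::Tuple, depth::Int, max_depth::Int)
    depth > max_depth && return xs
    n = size(first(xs), 3)
    stride = 2^(depth - 1)
    # combine each slice with the one `stride` positions earlier, in every array
    combined = op(map(x -> view(x, :, :, 1:n-stride), xs),
                  map(x -> view(x, :, :, stride+1:n), xs))
    xs = map((x, c) -> cat(view(x, :, :, 1:stride), c; dims = 3), xs, combined)
    return tuple_parallel_scan(op, xs, depth + 1, max_depth)
end

tuple_parallel_scan(op, xs::Tuple) =
    tuple_parallel_scan(op, xs, 1, ceil(Int, log2(size(first(xs), 3))))

# e.g. the coupled matrix products from the first post, on two 3x3x50000 arrays:
double_batched_matmul((Ai, Bi), (Aj, Bj)) = (batched_mul(Ai, Aj), batched_mul(Bi, Bj))
As, Bs = create_normalized_data(), create_normalized_data()
Aout, Bout = tuple_parallel_scan(double_batched_matmul, (As, Bs))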