QuantEcon / book-dp1-public-companion

Dynamic Programming Volume 1
https://quantecon.github.io/book-dp1-public-companion/
BSD 3-Clause "New" or "Revised" License

Shift some Julia code to the GPU #6

Open jstac opened 1 year ago

jstac commented 1 year ago

How does shifting some of the Julia code to the GPU affect timing?

It would be interesting to experiment with putting at least some Julia code on the GPU and testing speed and relative timing of algorithms. While Julia does not have JAX, it's still possible to put array operations on the GPU and have them run efficiently.

See, for example, https://cuda.juliagpu.org/stable/usage/array/
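For context, a brief sketch of why this works (hedged: `scaled_residual` is an illustrative name, not from the book's code): CUDA.jl's `CuArray` is a drop-in array type, so generic Julia code written with broadcasting and linear algebra runs on the GPU unchanged when passed GPU arrays.

```jl
using LinearAlgebra

# Generic code: the same function runs on the CPU with Arrays and on the
# GPU with CuArrays (from CUDA.jl), because broadcasting and * dispatch
# on the array type.
function scaled_residual(A, b, α)
    y = A * b
    return sum(abs.(y .- α .* b))
end

A = Matrix{Float64}(I, 3, 3)   # on the GPU this would be a CuArray
b = [1.0, 2.0, 3.0]
scaled_residual(A, b, 2.0)     # |b - 2b| summed: 1 + 2 + 3 = 6
```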

A suggested sequence of steps is

  1. Add a new version of https://github.com/QuantEcon/book-dp1/blob/main/source_code_jl/finite_lq.jl that is all based on array processing operations, similar to https://github.com/QuantEcon/cbc_workshops/blob/main/day_4/investment_jax.ipynb (which is the same model in JAX).
  2. After checking that step 1 produces the right results, modify the array operations to execute on the GPU via https://cuda.juliagpu.org/stable/usage/array/ or similar
  3. Test it on Colab via https://colab.research.google.com/github/ageron/julia_notebooks/blob/master/Julia_Colab_Notebook_Template.ipynb or similar

Another alternative is to add a new version of https://github.com/QuantEcon/book-dp1/blob/main/source_code_jl/finite_lq.jl that uses loops rather than arrays and manually parallelize these loops on the GPU using techniques discussed in https://cuda.juliagpu.org/stable/tutorials/introduction/.
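For illustration, a hand-written kernel along the lines of that tutorial might look as follows (a sketch only: `square_kernel!` is a hypothetical name, and a CUDA-capable GPU is assumed):

```jl
using CUDA  # assumes CUDA.jl and a CUDA-capable GPU

# One thread per element; each thread squares its own entry in place.
function square_kernel!(y)
    i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    if i <= length(y)
        @inbounds y[i] = y[i]^2
    end
    return nothing
end

y = CUDA.fill(3.0f0, 1024)
@cuda threads=256 blocks=4 square_kernel!(y)   # launch 4 × 256 = 1024 threads
# On a working GPU, Array(y) is now filled with 9.0f0
```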

jstac commented 1 year ago

@Smit-create, I have assigned you to this, but only take it up if you are interested. It's quite challenging and not in your preferred language.

One reason I think this experiment should be carried out is that the relative timings in, say, Fig 6.9 and the surrounding discussion will change when the computations run on the GPU. With Python / JAX, I found that HPI becomes the fastest alternative.

Smit-create commented 1 year ago

@jstac Thanks! I will have a look into this soon and try out this experiment.

Smit-create commented 1 year ago

Hi @jstac, I was trying to use CUDA for finite_lq.jl and see that while it uses array programming, it also relies on scalar indexing for most of the operations (via for loops). Also, see the warning I get:

Warning: Performing scalar indexing on task Task (runnable) @0x00007f074f9fd430.
│ Invocation of getindex resulted in scalar indexing of a GPU array.
│ This is typically caused by calling an iterating implementation of a method.
│ Such implementations *do not* execute on the GPU, but very slowly on the CPU,
│ and therefore are only permitted from the REPL for prototyping purposes.
│ If you did intend to index this array, annotate the caller with @allowscalar.
└ @ GPUArraysCore ~/.julia/packages/GPUArraysCore/lojQM/src/GPUArraysCore.jl:90
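For reference, the pattern the warning describes, and the whole-array alternative, can be sketched as follows (assuming CUDA.jl and a CUDA-capable GPU):

```jl
using CUDA  # sketch only; assumes CUDA.jl and a CUDA-capable GPU

v = CUDA.ones(Float32, 100)

# Scalar indexing: each v[i] below is a separate host-device transfer, which
# triggers the warning above (or errors under CUDA.allowscalar(false)):
# s = 0.0f0
# for i in eachindex(v)
#     s += v[i]
# end

# GPU-friendly equivalents: whole-array operations stay on the device.
s = sum(v)            # reduction kernel on the GPU
w = 2f0 .* v .+ 1f0   # broadcast fuses into a single GPU kernel
```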

I then tried using Threads.@threads to parallelize some of the code, as we do in Python's Numba with parallel=True and prange:

With Threads

```jl
using QuantEcon, LinearAlgebra

function create_investment_model(;
        r=0.04,                             # Interest rate
        a_0=10.0, a_1=1.0,                  # Demand parameters
        γ=25.0, c=1.0,                      # Adjustment and unit cost
        y_min=0.0, y_max=20.0, y_size=100,  # Grid for output
        ρ=0.9, ν=1.0,                       # AR(1) parameters
        z_size=25)                          # Grid size for shock
    β = 1/(1+r)
    y_grid = LinRange(y_min, y_max, y_size)
    mc = tauchen(z_size, ρ, ν)
    z_grid, Q = mc.state_values, mc.p
    return (; β, a_0, a_1, γ, c, y_grid, z_grid, Q)
end

"""
The aggregator B is given by

    B(y, z, y′) = r(y, z, y′) + β Σ_z′ v(y′, z′) Q(z, z′)

where

    r(y, z, y′) := (a_0 - a_1 * y + z - c) y - γ * (y′ - y)^2
"""
function B(i, j, k, v, model)
    (; β, a_0, a_1, γ, c, y_grid, z_grid, Q) = model
    y, z, y′ = y_grid[i], z_grid[j], y_grid[k]
    r = (a_0 - a_1 * y + z - c) * y - γ * (y′ - y)^2
    return r + β * dot(v[k, :], Q[j, :])
end

"Compute a v-greedy policy."
function get_greedy(v, model)
    y_idx, z_idx = (eachindex(g) for g in (model.y_grid, model.z_grid))
    σ = Matrix{Int32}(undef, length(y_idx), length(z_idx))
    Threads.@threads for i in y_idx
        Threads.@threads for j in z_idx
            @inbounds _, σ[i, j] = findmax(B(i, j, k, v, model) for k in y_idx)
        end
    end
    return σ
end

"Optimistic policy iteration routine."
function optimistic_policy_iteration(model; tol=1e-5, m=100)
    v = zeros(length(model.y_grid), length(model.z_grid))
    error = tol + 1
    while error > tol
        last_v = v
        σ = get_greedy(v, model)
        for i in 1:m
            v = T_σ(v, σ, model)   # T_σ (the policy operator) comes from finite_lq.jl
        end
        error = maximum(abs.(v - last_v))
    end
    return get_greedy(v, model)
end

function check_timing()
    model = create_investment_model()
    (; β, a_0, a_1, γ, c, y_grid, z_grid, Q) = model
    σ_star = optimistic_policy_iteration(model)
end
```

Timings:

% export JULIA_NUM_THREADS=8 
% time julia a.jl           
julia a.jl  6.84s user 0.95s system 128% cpu 6.066 total
% time julia a.jl
julia a.jl  7.70s user 0.20s system 130% cpu 6.042 total
% time julia a.jl
julia a.jl  6.69s user 1.27s system 130% cpu 6.095 total
Without Threads

```jl
using QuantEcon, LinearAlgebra

function create_investment_model(;
        r=0.04,                             # Interest rate
        a_0=10.0, a_1=1.0,                  # Demand parameters
        γ=25.0, c=1.0,                      # Adjustment and unit cost
        y_min=0.0, y_max=20.0, y_size=100,  # Grid for output
        ρ=0.9, ν=1.0,                       # AR(1) parameters
        z_size=25)                          # Grid size for shock
    β = 1/(1+r)
    y_grid = LinRange(y_min, y_max, y_size)
    mc = tauchen(z_size, ρ, ν)
    z_grid, Q = mc.state_values, mc.p
    return (; β, a_0, a_1, γ, c, y_grid, z_grid, Q)
end

"""
The aggregator B is given by

    B(y, z, y′) = r(y, z, y′) + β Σ_z′ v(y′, z′) Q(z, z′)

where

    r(y, z, y′) := (a_0 - a_1 * y + z - c) y - γ * (y′ - y)^2
"""
function B(i, j, k, v, model)
    (; β, a_0, a_1, γ, c, y_grid, z_grid, Q) = model
    y, z, y′ = y_grid[i], z_grid[j], y_grid[k]
    r = (a_0 - a_1 * y + z - c) * y - γ * (y′ - y)^2
    return r + β * dot(v[k, :], Q[j, :])
end

function get_greedy(v, model)
    y_idx, z_idx = (eachindex(g) for g in (model.y_grid, model.z_grid))
    σ = Matrix{Int32}(undef, length(y_idx), length(z_idx))
    for (i, j) in Iterators.product(y_idx, z_idx)   # product lives in Iterators
        _, σ[i, j] = findmax(B(i, j, k, v, model) for k in y_idx)
    end
    return σ
end

"Optimistic policy iteration routine."
function optimistic_policy_iteration(model; tol=1e-5, m=100)
    v = zeros(length(model.y_grid), length(model.z_grid))
    error = tol + 1
    while error > tol
        last_v = v
        σ = get_greedy(v, model)
        for i in 1:m
            v = T_σ(v, σ, model)   # T_σ (the policy operator) comes from finite_lq.jl
        end
        error = maximum(abs.(v - last_v))
    end
    return get_greedy(v, model)
end

function check_timing()
    model = create_investment_model()
    (; β, a_0, a_1, γ, c, y_grid, z_grid, Q) = model
    σ_star = optimistic_policy_iteration(model)
end
```

Timings:

% time julia b.jl
julia b.jl  7.28s user 0.58s system 130% cpu 6.007 total
% time julia b.jl
julia b.jl  7.23s user 0.69s system 130% cpu 6.076 total
jstac commented 1 year ago

Thanks for pushing this forward @Smit-create .

I think the issue here is that we should eliminate the for loops and shift to a purely "vectorized" array-processing approach. Please see step 1 in my suggested sequence above. (Notice how the JAX version avoids for loops.) Doing so will let the Julia compiler determine optimal parallelization.
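To make that concrete, here is a sketch of what a fully vectorized evaluation of B could look like (illustrative only: `B_all` is a hypothetical name, and the reshape-and-broadcast pattern mirrors the JAX notebook rather than the book's code):

```jl
using LinearAlgebra

# Evaluate B(y, z, y′) = r(y, z, y′) + β Σ_z′ v(y′, z′) Q(z, z′) over the
# whole grid at once: reshape the grids so broadcasting yields a
# (y, z, y′) array, with no explicit loops.
function B_all(v, β, a_0, a_1, γ, c, y_grid, z_grid, Q)
    n_y, n_z = length(y_grid), length(z_grid)
    y  = reshape(y_grid, n_y, 1, 1)      # y varies along dimension 1
    z  = reshape(z_grid, 1, n_z, 1)      # z varies along dimension 2
    y′ = reshape(y_grid, 1, 1, n_y)      # y′ varies along dimension 3
    r = @. (a_0 - a_1 * y + z - c) * y - γ * (y′ - y)^2
    EV = v * Q'                          # EV[k, j] = Σ_z′ v[k, z′] Q[j, z′]
    return r .+ β .* reshape(EV', 1, n_z, n_y)
end
```

A v-greedy policy is then an argmax over the third dimension, and if `v` and the grids are `CuArray`s, the same broadcasts and matrix multiply should run on the GPU.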

I don't expect Julia to match JAX, because JAX has a very good GPU-aware JIT compiler.

Smit-create commented 1 year ago

Yeah, I see. I understood the JAX code in Python, but converting it to Julia with CUDA is a bit difficult: there are significant differences in how the reshaping and broadcasting carry over. NumPy and JAX use row-major ordering, while Julia uses column-major ordering, so we can't port the code directly; it needs some more tweaks.
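A small illustration of the ordering point (hedged: just a sketch of the difference, not the actual porting issue in the notebook):

```jl
# Julia's reshape fills column-major: the first index varies fastest.
M = reshape(1:6, 2, 3)
# 2×3 array:
#  1  3  5
#  2  4  6
# NumPy's arange(6).reshape(2, 3) fills row-major instead:
#  [[0 1 2]
#   [3 4 5]]
# so the broadcast dimensions used in the JAX notebook need permuting in Julia.
```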

jstac commented 1 year ago

Thanks for the report @Smit-create . I understand.

Let's put it aside for now. I think we should focus on converting our Python code on https://python.quantecon.org/intro.html over to JAX --- especially where we find we can generate speed gains with fewer lines of code.