jstac / sdfs_via_autodiff

Computing wealth consumption ratios and stochastic discount factors under smooth recursive utility

Failure of vectorization for the GCY model in the discrete case. #3

Closed jstac closed 1 year ago

jstac commented 1 year ago

I'm stuck on some code that isn't working as I hoped. The issue is NumPy / JAX style broadcasting.

The code is divided into ssy and gcy, which are two alternative asset pricing models:

https://github.com/jstac/sdfs_via_autodiff/tree/main/code

Broadcasting is working on the ssy side because this function returns True:

https://github.com/jstac/sdfs_via_autodiff/blob/main/code/ssy/discrete/ssy_wc_ratio.py#L172

(The functions return true if the operation that uses broadcasting produces the same result as the operation that uses loops.)

It's failing on the gcy side because this function returns False:

https://github.com/jstac/sdfs_via_autodiff/blob/main/code/gcy/discrete/gcy_wc_ratio.py#L306

All help is appreciated.

JunnanZ commented 1 year ago

Hi @jstac, I'm trying to understand your code. Do you reshape so that broadcasting happens along each dimension? And the final jnp.sum is to evaluate the expectation?

I wonder if there is a more readable way to write this. Perhaps np.dot can be more efficient, but it wouldn't improve readability I guess?

By the way, it seems the wc ratios in the SSY model are not computed correctly in the current version. Newton's method returns negative values and successive approximation returns an array of Inf. Does this also happen on your machine?

jstac commented 1 year ago

> Hi @jstac, I'm trying to understand your code. Do you reshape so that broadcasting happens along each dimension? And the final jnp.sum is to evaluate the expectation?

Exactly. I stretch out all arrays to the maximum dimension, covering all states for today and tomorrow (e.g. 4+4 for SSY), and then sum across the tomorrow axes (e.g. the last 4 for SSY).
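
To make the stretch-and-sum pattern concrete, here is a minimal NumPy sketch on a hypothetical two-state model (states x and y, with 2+2 axes instead of 4+4). The arrays `P`, `Q`, `a` and `w` are illustrative assumptions and are not taken from the repo. It checks the broadcast version against explicit loops, in the same spirit as the test functions linked above.

```python
import numpy as np

rng = np.random.default_rng(0)
n_x, n_y = 3, 4

# Hypothetical inputs (assumptions for illustration only).
P = rng.random((n_x, n_x)); P /= P.sum(axis=1, keepdims=True)  # P[i, k] = Prob(x'=k | x=i)
Q = rng.random((n_y, n_y)); Q /= Q.sum(axis=1, keepdims=True)  # Q[j, l] = Prob(y'=l | y=j)
a = rng.random((n_x, n_x))   # a[i, k]: a term depending on x today and x' tomorrow
w = rng.random((n_x, n_y))   # w[k, l]: candidate solution evaluated at tomorrow's state

# Loop version: Tw[i, j] = sum over (k, l) of P[i, k] Q[j, l] a[i, k] w[k, l]
Tw_loops = np.zeros((n_x, n_y))
for i in range(n_x):
    for j in range(n_y):
        for k in range(n_x):
            for l in range(n_y):
                Tw_loops[i, j] += P[i, k] * Q[j, l] * a[i, k] * w[k, l]

# Broadcast version: every array is stretched to the axes (x, y, x', y'),
# then the tomorrow axes (the last two) are summed out.
P4 = P.reshape(n_x, 1, n_x, 1)
Q4 = Q.reshape(1, n_y, 1, n_y)
a4 = a.reshape(n_x, 1, n_x, 1)
w4 = w.reshape(1, 1, n_x, n_y)
Tw_vec = np.sum(P4 * Q4 * a4 * w4, axis=(2, 3))

agree = np.allclose(Tw_loops, Tw_vec)
```

Each axis keeps one fixed meaning across all four reshaped arrays, so the product lines up only if every array's axes are in that same order.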

> I wonder if there is a more readable way to write this. Perhaps np.dot can be more efficient, but it wouldn't improve readability I guess?

I wonder if it would.
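
One way to explore the question is a small sketch with `np.einsum` and plain matrix products, again on a hypothetical two-state model whose arrays (`P`, `Q`, `a`, `w`) are assumptions rather than the repo's objects. The matrix-product form relies on the transition probabilities factoring across states, so it may not carry over directly to the SSY or GCY code.

```python
import numpy as np

rng = np.random.default_rng(1)
n_x, n_y = 3, 4

# Hypothetical inputs (assumptions for illustration only).
P = rng.random((n_x, n_x)); P /= P.sum(axis=1, keepdims=True)  # P[i, k]
Q = rng.random((n_y, n_y)); Q /= Q.sum(axis=1, keepdims=True)  # Q[j, l]
a = rng.random((n_x, n_x))   # a[i, k]
w = rng.random((n_x, n_y))   # w[k, l]

# einsum names every axis explicitly: i, j index today, k, l index tomorrow.
Tw_einsum = np.einsum('ik,jl,ik,kl->ij', P, Q, a, w)

# Matrix-product form: contract over k first, then over l.
# No array with all four axes (i, j, k, l) is formed explicitly.
Tw_matmul = (P * a) @ w @ Q.T

agree = np.allclose(Tw_einsum, Tw_matmul)
```

The einsum subscripts double as documentation of which axis is which, and that is arguably the readability gain; whether it is faster depends on the backend and the contraction order it picks.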

I wish there were a way to improve readability for vectorized code. I hate writing it. That's why I wrote the loops version, to check. I'm waiting for ChatGPT to learn to write vectorized code from my loops...

> By the way, it seems the wc ratios in the SSY model are not computed correctly in the current version. Newton's method returns negative values and successive approximation returns an array of Inf. Does this also happen on your machine?

No, it's OK on my end. Are you sure you have the latest? Did you use the function test_compute_wc_ratio_ssy at the defaults?

jstac commented 1 year ago

One more comment, @JunnanZ. The lack of readability of the code really gets bad when we switch to GCY. You can see that I've tried to be more "systematic" about construction of the array indices when I stretch across all dimensions for the vectorization step. But I'm clearly doing something wrong --- even though I try to follow the same logic.

JunnanZ commented 1 year ago

> No, it's OK on my end. Are you sure you have the latest? Did you use the function test_compute_wc_ratio_ssy at the defaults?

@jstac Yes, I'm on 3b5cb51314be08bd1c883e994eedbb8001e315b3, the latest in the repo. Evaluating

shapes = (10, 10, 10, 10)
w = test_compute_wc_ratio_ssy(shapes, algo="newton")

returns

Beginning iteration

iter = 0, error = 800.3094946759451
iter = 1, error = 0.0
Iteration converged after 2 iterations
TOC: Elapsed: 0:00:1.16
Computed solution in 1.168564796447754 seconds.

and a matrix of negative values. And w = test_compute_wc_ratio_ssy(algo="successive_approx") returns a matrix of Inf.

jstac commented 1 year ago

Weird. I've just pulled the latest and I get

test_compute_wc_ratio_ssy(shapes=(10,10,10,10), algo="newton")                   
/home/john/gh_synced/papers/sdfs_via_autodiff/code/ssy/discrete/ssy_wc_ratio.py:48: UserWarning: The API of rouwenhorst has changed from `rouwenhorst(n, ybar, sigma, rho)` to `rouwenhorst(n, rho, sigma, mu=0.)`. To find more details please visit: https://github.com/QuantEcon/QuantEcon.py/issues/663.
  h_λ_mc = rouwenhorst(n_h_λ, ρ_λ, s_λ, 0)
/home/john/gh_synced/papers/sdfs_via_autodiff/code/ssy/discrete/ssy_wc_ratio.py:49: UserWarning: The API of rouwenhorst has changed from `rouwenhorst(n, ybar, sigma, rho)` to `rouwenhorst(n, rho, sigma, mu=0.)`. To find more details please visit: https://github.com/QuantEcon/QuantEcon.py/issues/663.
  h_c_mc = rouwenhorst(n_h_c, ρ_c, s_c, 0)
/home/john/gh_synced/papers/sdfs_via_autodiff/code/ssy/discrete/ssy_wc_ratio.py:50: UserWarning: The API of rouwenhorst has changed from `rouwenhorst(n, ybar, sigma, rho)` to `rouwenhorst(n, rho, sigma, mu=0.)`. To find more details please visit: https://github.com/QuantEcon/QuantEcon.py/issues/663.
  h_z_mc = rouwenhorst(n_h_z, ρ_z, s_z, 0)
/home/john/gh_synced/papers/sdfs_via_autodiff/code/ssy/discrete/ssy_wc_ratio.py:63: UserWarning: The API of rouwenhorst has changed from `rouwenhorst(n, ybar, sigma, rho)` to `rouwenhorst(n, rho, sigma, mu=0.)`. To find more details please visit: https://github.com/QuantEcon/QuantEcon.py/issues/663.
  mc_z = rouwenhorst(n_z, ρ, σ_z, 0)
Beginning iteration                                                                            

iter = 0, error = 4302.341800771495                                                            
iter = 1, error = 4074.9605304521597                                                           
iter = 2, error = 112.01772152357796                                                           
iter = 3, error = 3.834976201446807                                                            
iter = 4, error = 0.019254899159136585                                                         
iter = 5, error = 0.0                                                                          
Iteration converged after 6 iterations                                                         
TOC: Elapsed: 0:00:42.44                                                                       
Computed solution in 42.44135785102844 seconds.            

The solution looks reasonable.

Edit: I'm running this at home on my CPU but I can't see why that would matter, given that we're using 64 bit floats...

JunnanZ commented 1 year ago

Found the culprit! I was using an old version of quantecon, which caused the problem.

Wow, Newton's method significantly speeds up the computation in the discrete model.

However, it seems that it still generates large intermediate arrays? For example, if I run

shapes = (15, 15, 15, 15)
w = test_compute_wc_ratio_ssy(shapes, algo="newton")

it returns

BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   395.5KiB
              constant allocation:    67.3KiB
        maybe_live_out allocation:   395.5KiB
     preallocated temp allocation:   19.19GiB
  preallocated temp fragmentation:     1.2KiB (0.00%)
                 total allocation:   19.19GiB
              total fragmentation:    70.4KiB (0.00%)
Peak buffers:
    Buffer 1:
        Size: 19.09GiB
        Operator: op_name="jit(q)/jit(main)/while/body/jvp(jit(T_ssy))/mul" source_file="/home/zhangjn/Documents/Economics/projects/sdfs_via_autodiff/code/ssy/discrete/ssy_wc_ratio.py" source_line=115 deduplicated_name="fusion.15"
        XLA Label: fusion
        Shape: f64[15,15,15,15,15,15,15,15]
        ==========================

    Buffer 2:
        Size: 49.44MiB
        Operator: op_name="jit(q)/jit(main)/while/body/jvp(jit(T_ssy))/reduce_sum[axes=(4, 5, 6, 7)]" source_file="/home/zhangjn/Documents/Economics/projects/sdfs_via_autodiff/code/ssy/discrete/ssy_wc_ratio.py" source_line=117
        XLA Label: fusion
        Shape: f64[15,15,15,15,128]
        ==========================

    Buffer 3:
......

where Buffer 1 is an $n^2$-sized array: with $n = 15^4$ grid points, it holds $15^8$ float64 entries, which matches the reported 19.09GiB. Do you also have this issue on your end? @jstac
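
As a quick sanity check on the size of Buffer 1, the arithmetic can be sketched in a few lines of Python. The only inputs are the reported shape `f64[15,15,15,15,15,15,15,15]` and 8 bytes per float64; the variable names are illustrative.

```python
import math

shape_today = (15, 15, 15, 15)      # grid sizes from the example above
n = math.prod(shape_today)          # number of grid points for today's state

# Stretching over today's and tomorrow's states gives n * n entries.
entries = n ** 2
bytes_needed = entries * 8          # float64 uses 8 bytes per entry
gib = bytes_needed / 2**30

print(f"n = {n}, entries = {entries}, size = {gib:.2f} GiB")
```

Since the size grows with the eighth power of the per-axis grid size, moving from 10 to 15 points per axis multiplies the memory for this buffer by roughly 25.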

jstac commented 1 year ago

Hmmm, that's a little disappointing.

How did you generate the buffer assignment debugging?

JunnanZ commented 1 year ago

I'm not sure if there is a better way to trace memory usage, but I just chose a very large shape to trigger an out-of-memory error. XLA then displays the debug message.

Smit-create commented 1 year ago

I think the following is the source of the potential bug:

state_numbers = {
    'z':    0,
    'z_π':  1,
    'h_z':  2,
    'h_c':  3,
    'h_zπ': 4,
    'h_λ':  5,
}

Since the axes are strictly ordered as given above, every matrix used in the code must follow the same axis order; otherwise the vectorized result will be wrong. You can verify this on the branch https://github.com/jstac/sdfs_via_autodiff/tree/try1.

Once you run the gcy code on the above branch, you will see two different outputs for z_Q:

line 363 z_Q (2, 2, 2, 2, 2) [0.9915 0.0085]
line 366 z_Q (2, 2, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1) [0.0085 0.9915]

This is because z_Q is indexed as [i_z_π, i_h_z, i_h_zπ, i_z, j_z], whereas the reshape treats its axes as if they followed the order (n_z, n_z_π, n_h_z, n_h_c, n_h_zπ, n_h_λ), and that's where the bug lies. If z_Q could somehow be represented in the form [i_z, i_z_π, i_h_z, i_h_zπ, j_z], it would fix the bug. The same applies to all the matrices used in the code.
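
A small NumPy sketch may help show why the axis-order mismatch described above matters. The array below is random and only borrows the name and index layout of z_Q; the grid sizes and the use of `np.moveaxis` are assumptions for illustration and are not necessarily how the branch fixes it.

```python
import numpy as np

rng = np.random.default_rng(0)
n_z, n_z_pi, n_h_z, n_h_zpi = 2, 3, 4, 5   # distinct sizes expose ordering errors

# Hypothetical stand-in for z_Q, indexed as [i_z_π, i_h_z, i_h_zπ, i_z, j_z].
z_Q = rng.random((n_z_pi, n_h_z, n_h_zpi, n_z, n_z))

# reshape never reorders data: the shape looks right but the entries are scrambled.
naive = z_Q.reshape(n_z, n_z_pi, n_h_z, n_h_zpi, n_z)

# Move the i_z axis (position 3) to the front, giving
# [i_z, i_z_π, i_h_z, i_h_zπ, j_z], which matches the state_numbers ordering.
fixed = np.moveaxis(z_Q, 3, 0)

# Only now is it safe to insert singleton axes for broadcasting over
# (z, z_π, h_z, h_c, h_zπ, h_λ) today and the same six states tomorrow.
stretched = fixed.reshape(n_z, n_z_pi, n_h_z, 1, n_h_zpi, 1, n_z, 1, 1, 1, 1, 1)

same_shape = naive.shape == fixed.shape
same_values = np.array_equal(naive, fixed)
```

Here `same_shape` holds while `same_values` does not. When all grid sizes are equal, as with shape=(2,2,2,2,2,2), there is not even a shape difference to hint at the problem, so testing with distinct sizes per axis is a useful check.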

Why didn't this fail in ssy?

Because there we transform (n_h_z, n_z, n_z) to (n_h_λ, n_h_c, n_h_z, n_z), which already follows the required order (n_h_z, then n_z).

I might not be good at explaining the bug in this comment because it's a tricky one and needs some visual explanation. I would be happy to discuss this over a call @jstac.

Smit-create commented 1 year ago

I tried fixing the above bug and it worked for me for shape=(2,2,2,2,2,2), which was failing earlier. I'll investigate further for shape=(2,3,4,5,6,7), which is still failing.

jstac commented 1 year ago

Great detective work @Smit-create ! I had in the back of my mind it must be some issue like the one you explained, but couldn't find the time to study it carefully. Please let me know how you go with the shape=(2,3,4,5,6,7) case.

Smit-create commented 1 year ago

@jstac I got everything running for all shape cases. You can try out the branch try1. It all works with numpy. I'll create a new branch soon with jax and clean code (free from debug comments), and it will be good to go.