Hi @jstac, I'm trying to understand your code. Do you reshape so that broadcasting happens along each dimension? And the final jnp.sum is to evaluate the expectation?
I wonder if there is a more readable way to write this. Perhaps np.dot can be more efficient, but it wouldn't improve readability, I guess?
By the way, it seems the wc ratios in the SSY model are not computed correctly in the current version. Newton's method returns negative values and successive approximation returns an array of Inf. Does this also happen on your machine?
Hi @jstac, I'm trying to understand your code. Do you reshape so that broadcasting happens along each dimension? And the final jnp.sum is to evaluate the expectation?
Exactly. I stretch out all arrays to the maximum dimension, covering all states for today and tomorrow (e.g. 4 + 4 axes for SSY), and then sum across the tomorrow axes (e.g. the last 4 for SSY).
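A minimal NumPy sketch of that pattern, for a hypothetical model with two state variables (x, y) instead of SSY's four. Every array is stretched to the 2 + 2 axes (x, y, x', y'), and the expectation is a sum over the two tomorrow axes. The independent transitions `P_x`, `P_y` and the payoff `f` are illustrative assumptions, not the repo's code.

```python
import numpy as np

def expectation_broadcast(P_x, P_y, f):
    """E[f(x', y') | x, y] via broadcasting, assuming independent chains.

    Axes of the stretched arrays: (x, y, x', y').
    """
    n_x, n_y = P_x.shape[0], P_y.shape[0]
    # Put each factor on its own axes; size-1 axes broadcast.
    Px = P_x.reshape(n_x, 1, n_x, 1)   # (x, ., x', .)
    Py = P_y.reshape(1, n_y, 1, n_y)   # (., y, ., y')
    F = f.reshape(1, 1, n_x, n_y)      # (., ., x', y')
    # Joint transition weights times tomorrow's values, summed over x', y'.
    return np.sum(Px * Py * F, axis=(2, 3))

rng = np.random.default_rng(0)
P_x = rng.random((3, 3)); P_x /= P_x.sum(axis=1, keepdims=True)
P_y = rng.random((4, 4)); P_y /= P_y.sum(axis=1, keepdims=True)
f = rng.random((3, 4))
Ef = expectation_broadcast(P_x, P_y, f)   # shape (3, 4): one value per today state
```

With four state variables the same idea gives 8 axes and `axis=(4, 5, 6, 7)`, which is the `reduce_sum[axes=(4, 5, 6, 7)]` that appears in the XLA trace further down.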
I wonder if there is a more readable way to write this. Perhaps np.dot can be more efficient, but it wouldn't improve readability, I guess?
I wonder if it would.
I wish there were a way to improve the readability of vectorized code. I hate writing it. That's why I wrote the loops version, as a check. I'm waiting for ChatGPT to learn to write vectorized code from my loops...
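One option for readability is `np.einsum`, which labels every axis with a letter so the "today" and "tomorrow" indices are explicit. A minimal sketch for the same hypothetical two-variable model with independent transitions; whether it is faster than broadcasting or `np.dot` in the actual SSY code is untested.

```python
import numpy as np

def expectation_einsum(P_x, P_y, f):
    """E[f(x', y') | x, y] for independent chains, written with einsum.

    i, k index today's (x, y); j, l index tomorrow's (x', y').
    Indices that appear in the inputs but not after "->" (j and l)
    are summed over, which is the expectation over tomorrow's states.
    """
    return np.einsum("ij,kl,jl->ik", P_x, P_y, f)

rng = np.random.default_rng(1)
P_x = rng.random((3, 3)); P_x /= P_x.sum(axis=1, keepdims=True)
P_y = rng.random((4, 4)); P_y /= P_y.sum(axis=1, keepdims=True)
f = rng.random((3, 4))
Ef = expectation_einsum(P_x, P_y, f)
```

The same subscripts string works with `jax.numpy.einsum`. In NumPy, passing `optimize=True` asks einsum to pick a pairwise contraction order instead of evaluating everything in one pass.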
By the way, it seems the wc ratios in the SSY model are not computed correctly in the current version. Newton's method returns negative values and successive approximation returns an array of Inf. Does this also happen on your machine?
No, it's OK on my end. Are you sure you have the latest? Did you use the function test_compute_wc_ratio_ssy at the defaults?
One more comment, @JunnanZ. The lack of readability really gets bad when we switch to GCY. You can see that I've tried to be more "systematic" about constructing the array indices when I stretch across all dimensions for the vectorization step. But I'm clearly doing something wrong, even though I try to follow the same logic.
No, it's OK on my end. Are you sure you have the latest? Did you use the function test_compute_wc_ratio_ssy at the defaults?
@jstac Yes, I'm on 3b5cb51314be08bd1c883e994eedbb8001e315b3, the latest in the repo. Evaluating
shapes = (10, 10, 10, 10)
w = test_compute_wc_ratio_ssy(shapes, algo="newton")
returns
Beginning iteration
iter = 0, error = 800.3094946759451
iter = 1, error = 0.0
Iteration converged after 2 iterations
TOC: Elapsed: 0:00:1.16
Computed solution in 1.168564796447754 seconds.
and a matrix of negative values. And w = test_compute_wc_ratio_ssy(algo="successive_approx") returns a matrix of Inf.
Weird. I've just pulled the latest and I get
test_compute_wc_ratio_ssy(shapes=(10,10,10,10), algo="newton")
/home/john/gh_synced/papers/sdfs_via_autodiff/code/ssy/discrete/ssy_wc_ratio.py:48: UserWarning: The API of rouwenhorst has changed from `rouwenhorst(n, ybar, sigma, rho)` to `rouwenhorst(n, rho, sigma, mu=0.)`. To find more details please visit: https://github.com/QuantEcon/QuantEcon.py/issues/663.
h_λ_mc = rouwenhorst(n_h_λ, ρ_λ, s_λ, 0)
/home/john/gh_synced/papers/sdfs_via_autodiff/code/ssy/discrete/ssy_wc_ratio.py:49: UserWarning: The API of rouwenhorst has changed from `rouwenhorst(n, ybar, sigma, rho)` to `rouwenhorst(n, rho, sigma, mu=0.)`. To find more details please visit: https://github.com/QuantEcon/QuantEcon.py/issues/663.
h_c_mc = rouwenhorst(n_h_c, ρ_c, s_c, 0)
/home/john/gh_synced/papers/sdfs_via_autodiff/code/ssy/discrete/ssy_wc_ratio.py:50: UserWarning: The API of rouwenhorst has changed from `rouwenhorst(n, ybar, sigma, rho)` to `rouwenhorst(n, rho, sigma, mu=0.)`. To find more details please visit: https://github.com/QuantEcon/QuantEcon.py/issues/663.
h_z_mc = rouwenhorst(n_h_z, ρ_z, s_z, 0)
/home/john/gh_synced/papers/sdfs_via_autodiff/code/ssy/discrete/ssy_wc_ratio.py:63: UserWarning: The API of rouwenhorst has changed from `rouwenhorst(n, ybar, sigma, rho)` to `rouwenhorst(n, rho, sigma, mu=0.)`. To find more details please visit: https://github.com/QuantEcon/QuantEcon.py/issues/663.
mc_z = rouwenhorst(n_z, ρ, σ_z, 0)
Beginning iteration
iter = 0, error = 4302.341800771495
iter = 1, error = 4074.9605304521597
iter = 2, error = 112.01772152357796
iter = 3, error = 3.834976201446807
iter = 4, error = 0.019254899159136585
iter = 5, error = 0.0
Iteration converged after 6 iterations
TOC: Elapsed: 0:00:42.44
Computed solution in 42.44135785102844 seconds.
The solution looks reasonable.
Edit: I'm running this at home on my CPU, but I can't see why that would matter, given that we're using 64-bit floats...
Found the culprit! I was using an old version of quantecon, which caused the problem.
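The warnings above say older quantecon releases used `rouwenhorst(n, ybar, sigma, rho)`. Under that signature, a call written for the new order `(n, rho, sigma, mu)`, such as `rouwenhorst(n_z, ρ, σ_z, 0)`, would pass 0 as the persistence and silently produce an iid chain. One cheap guard, sketched here under that reading, is to check that each discretized chain reproduces the intended autocorrelation. The helper `chain_autocorr` is hypothetical; it only needs a transition matrix and a state grid (quantecon's `MarkovChain` exposes these as `mc.P` and `mc.state_values`).

```python
import numpy as np

def chain_autocorr(P, x):
    """Lag-1 autocorrelation of a finite Markov chain under its stationary law.

    P: (n, n) transition matrix, x: (n,) state values.
    """
    P = np.asarray(P, dtype=float)
    x = np.asarray(x, dtype=float)
    # Stationary distribution: eigenvector of P.T for the eigenvalue closest to 1.
    vals, vecs = np.linalg.eig(P.T)
    stat = np.real(vecs[:, np.argmin(np.abs(vals - 1.0))])
    stat = stat / stat.sum()
    dev = x - stat @ x              # deviations from the stationary mean
    var = stat @ dev**2
    # Cov(x_t, x_{t+1}) = sum_ij stat_i * P_ij * dev_i * dev_j
    cov = stat @ (dev[:, None] * P * dev[None, :]).sum(axis=1)
    return cov / var

# A persistent two-state chain vs. an iid one (what a dropped rho would give).
persistent = chain_autocorr([[0.9, 0.1], [0.1, 0.9]], [-1.0, 1.0])  # ≈ 0.8
iid = chain_autocorr([[0.5, 0.5], [0.5, 0.5]], [-1.0, 1.0])         # ≈ 0.0
```

For a Rouwenhorst chain the lag-1 autocorrelation should match the target ρ up to rounding, so an assertion like `abs(chain_autocorr(mc.P, mc.state_values) - ρ) < 1e-8` right after each call would flag the old-API mix-up.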
Wow, Newton's method significantly speeds up the computation in the discrete model.
However, it seems that it still generates large intermediate arrays. For example, if I run
shapes = (15, 15, 15, 15)
w = test_compute_wc_ratio_ssy(shapes, algo="newton")
it returns
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 395.5KiB
constant allocation: 67.3KiB
maybe_live_out allocation: 395.5KiB
preallocated temp allocation: 19.19GiB
preallocated temp fragmentation: 1.2KiB (0.00%)
total allocation: 19.19GiB
total fragmentation: 70.4KiB (0.00%)
Peak buffers:
Buffer 1:
Size: 19.09GiB
Operator: op_name="jit(q)/jit(main)/while/body/jvp(jit(T_ssy))/mul" source_file="/home/zhangjn/Documents/Economics/projects/sdfs_via_autodiff/code/ssy/discrete/ssy_wc_ratio.py" source_line=115 deduplicated_name="fusion.15"
XLA Label: fusion
Shape: f64[15,15,15,15,15,15,15,15]
==========================
Buffer 2:
Size: 49.44MiB
Operator: op_name="jit(q)/jit(main)/while/body/jvp(jit(T_ssy))/reduce_sum[axes=(4, 5, 6, 7)]" source_file="/home/zhangjn/Documents/Economics/projects/sdfs_via_autodiff/code/ssy/discrete/ssy_wc_ratio.py" source_line=117
XLA Label: fusion
Shape: f64[15,15,15,15,128]
==========================
Buffer 3:
......
where Buffer 1 is an $n^2$-sized array (here $n = 15^4$ states, so $n^2 = 15^8$ entries). Do you also have this issue on your end? @jstac
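The size checks out: $15^8 \approx 2.56 \times 10^9$ float64 entries at 8 bytes each is about 19.1 GiB, in line with Buffer 1. A minimal NumPy sketch of one way to avoid materializing that array, under the simplifying assumption that the joint transition is a product of per-axis matrices (not true of SSY as written, where volatility states affect other transitions): contract tomorrow's axes one at a time with `np.tensordot`, so every intermediate has $n$ entries rather than $n^2$. The function name is hypothetical.

```python
import numpy as np

# Fully stretched float64 array at shapes (15, 15, 15, 15): 8 axes of 15, 8 bytes each.
full_gib = 15 ** 8 * 8 / 2**30   # about 19.1 GiB

def expectation_sequential(Ps, f):
    """E[f(x') | x] when the joint transition is a product of per-axis matrices.

    Ps[k] is the (today, tomorrow) transition matrix for state axis k, and f
    holds tomorrow's values with one axis per state variable. Independence
    across axes is an assumption of this sketch.
    """
    out = f
    for k, P in enumerate(Ps):
        # Sum out tomorrow's index on axis k; tensordot appends today's
        # index as the last axis, so move it back into position k.
        out = np.moveaxis(np.tensordot(out, P, axes=([k], [1])), -1, k)
    return out

# Small check against the all-at-once (n^2-style) contraction.
rng = np.random.default_rng(0)
Ps = [rng.random((m, m)) for m in (2, 3, 4)]
Ps = [P / P.sum(axis=1, keepdims=True) for P in Ps]
f = rng.random((2, 3, 4))
full = np.einsum("ad,be,cf,def->abc", *Ps, f)
seq = expectation_sequential(Ps, f)
```

When the transitions are not independent, the same idea can still help if the contraction order is chosen so that the dependent axes are summed last, but that would need checking against the actual model.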
Hmmm, that's a little disappointing.
How did you generate the buffer assignment debugging?
I'm not sure if there is a better way to trace memory usage, but I just choose a very large shape to trigger an out-of-memory error. XLA then prints the debug message.
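A cheaper first check than triggering an OOM is to compute the size of a broadcast result before running it. A minimal NumPy sketch with a hypothetical helper `broadcast_gib`; it only counts the explicit broadcast output, so it is an estimate, since XLA may fuse operations or keep extra temporaries. For measured numbers, JAX provides `jax.profiler.save_device_memory_profile`.

```python
import numpy as np

def broadcast_gib(*shapes, itemsize=8):
    """Size in GiB of the array produced by broadcasting `shapes` together.

    Raises ValueError if the shapes are not broadcast-compatible.
    """
    out_shape = np.broadcast_shapes(*shapes)
    n_entries = int(np.prod(out_shape, dtype=np.int64))
    return n_entries * itemsize / 2**30

# Today's 4 axes (15 each) stretched against tomorrow's 4 axes.
today = (15, 15, 15, 15, 1, 1, 1, 1)
tomorrow = (1, 1, 1, 1, 15, 15, 15, 15)
size = broadcast_gib(today, tomorrow)
print(f"{size:.2f} GiB")
```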
I think the following is the potential source of the bug:
state_numbers = { 'z' : 0,
'z_π' : 1,
'h_z' : 2,
'h_c' : 3,
'h_zπ' : 4,
'h_λ' : 5}
Since the axes are strictly ordered as given above, all the matrices used in the code must follow the same order; otherwise, the computation fails. You can verify this on the branch https://github.com/jstac/sdfs_via_autodiff/tree/try1.
Once you run the gcy code on the above branch, you will see two different outputs for z_Q:
line 363 z_Q (2, 2, 2, 2, 2) [0.9915 0.0085]
line 366 z_Q (2, 2, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1) [0.0085 0.9915]
It's because z_Q is indexed as [i_z_π, i_h_z, i_h_zπ, i_z, j_z], whereas after reshaping we treat it as if it had the form (n_z, n_z_π, n_h_z, n_h_c, n_h_zπ, n_h_λ), and that's where the bug lies. If z_Q were indexed as [i_z, i_z_π, i_h_z, i_h_zπ, j_z], the bug would be fixed. The same applies to all the matrices used in the code.
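A minimal NumPy sketch of that fix: move each array's axes into the global order of `state_numbers` (today's six axes, then tomorrow's six) before inserting the singleton axes, so a plain reshape no longer mixes up indices. The helper `to_global` and the array sizes are hypothetical; the sizes are deliberately distinct so that a wrong ordering would show up in the shape.

```python
import numpy as np

ORDER = ["z", "z_π", "h_z", "h_c", "h_zπ", "h_λ"]   # global axis order

def to_global(arr, today, tomorrow):
    """Reorder `arr` and add size-1 axes to fit the 12-axis layout.

    today / tomorrow: state labels of arr's leading (t) and trailing (t+1)
    axes, in the order they are stored. Every label must appear in ORDER.
    """
    labels = [(lab, 0) for lab in today] + [(lab, 1) for lab in tomorrow]
    # Target position of each axis: today's axes 0-5, tomorrow's axes 6-11.
    pos = [ORDER.index(lab) + len(ORDER) * t for lab, t in labels]
    arr = np.transpose(arr, np.argsort(pos))   # axes now in global order
    shape = [1] * (2 * len(ORDER))
    for p, size in zip(sorted(pos), arr.shape):
        shape[p] = size
    return arr.reshape(shape)   # only size-1 axes are inserted here

# z_Q stored as [i_z_π, i_h_z, i_h_zπ, i_z, j_z] with made-up sizes.
z_Q = np.random.default_rng(0).random((3, 4, 5, 2, 2))
out = to_global(z_Q, today=["z_π", "h_z", "h_zπ", "z"], tomorrow=["z"])
print(out.shape)  # (2, 3, 4, 1, 5, 1, 2, 1, 1, 1, 1, 1)
```

Because the transpose happens first, the reshape only inserts length-1 axes and never reinterprets the data, which is the step that went wrong in the failing version.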
Why didn't this fail in SSY?
It's because there we transform n_h_z, n_z, n_z to n_h_λ, n_h_c, n_h_z, n_z, which preserves the order (n_h_z, then n_z).
I might not be good at explaining the bug in this comment because it's a tricky one that needs a visual explanation. I'd be happy to discuss this over a call, @jstac.
I tried fixing the above bug, and the fix works for shape=(2,2,2,2,2,2), which was failing earlier. I'll investigate further for shape=(2,3,4,5,6,7), which still fails.
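This is also why equal sizes can hide axis-order bugs: when every axis has the same length, an array with two axes swapped still has exactly the expected shape, so broadcasting succeeds and only the values are wrong. With distinct lengths, the mismatch often surfaces as a shape error instead (though a bare `reshape` still only checks the total size). A minimal NumPy illustration with arbitrary sizes; `broadcasts` is a hypothetical helper.

```python
import numpy as np

def broadcasts(a_shape, b_shape):
    """True if arrays with these shapes can be broadcast together."""
    try:
        np.broadcast_shapes(a_shape, b_shape)
        return True
    except ValueError:
        return False

# Intended layout (z, z_π) vs. an array whose two axes were accidentally swapped.
equal_sizes = broadcasts((2, 2), (2, 2)[::-1])     # True: the swap goes unnoticed
distinct_sizes = broadcasts((2, 3), (2, 3)[::-1])  # False: the swap is caught
```

Testing with shapes like (2,3,4,5,6,7) is therefore a stronger check than (2,2,2,2,2,2).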
Great detective work, @Smit-create! I had in the back of my mind that it must be some issue like the one you explained, but couldn't find the time to study it carefully. Please let me know how you go with the shape=(2,3,4,5,6,7) case.
@jstac I got everything running for all shape cases. You can try out the branch try1. It all works with NumPy. I'll soon create a new branch with JAX and clean code (free of debug comments), and then it will be good to go.
I'm stuck on some code that isn't working as I hoped. The issue is NumPy / JAX style broadcasting.
The code is divided into ssy and gcy, which are two alternative asset pricing models:
https://github.com/jstac/sdfs_via_autodiff/tree/main/code
Broadcasting is working on the ssy side because this function returns True:
https://github.com/jstac/sdfs_via_autodiff/blob/main/code/ssy/discrete/ssy_wc_ratio.py#L172
(These functions return True if the operation that uses broadcasting produces the same result as the operation that uses loops.)
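The idea behind those checks can be sketched generically: compute the same conditional expectation once with explicit loops and once with broadcasting, then compare with `np.allclose`. This is a hypothetical two-variable example in which x's transition depends on today's y (loosely mimicking volatility states), stored as `P_x[y, x, x']`; the names `expect_loops` and `expect_broadcast` are illustrative, not the functions linked above.

```python
import numpy as np

def expect_loops(P_x, P_y, f):
    """Reference version: E[f(x', y') | x, y] with explicit loops."""
    n_y, n_x, _ = P_x.shape
    out = np.zeros((n_x, n_y))
    for i in range(n_x):            # today's x
        for k in range(n_y):        # today's y
            for j in range(n_x):    # tomorrow's x
                for l in range(n_y):  # tomorrow's y
                    out[i, k] += P_x[k, i, j] * P_y[k, l] * f[j, l]
    return out

def expect_broadcast(P_x, P_y, f):
    """Vectorized version on the stretched layout (x, y, x', y')."""
    n_y, n_x, _ = P_x.shape
    # P_x is stored as (y, x, x'): reorder to (x, y, x') BEFORE reshaping.
    Px = np.transpose(P_x, (1, 0, 2)).reshape(n_x, n_y, n_x, 1)
    Py = P_y.reshape(1, n_y, 1, n_y)
    F = f.reshape(1, 1, n_x, n_y)
    return np.sum(Px * Py * F, axis=(2, 3))

rng = np.random.default_rng(2)
n_x, n_y = 3, 4                     # distinct sizes expose axis mix-ups
P_x = rng.random((n_y, n_x, n_x)); P_x /= P_x.sum(axis=2, keepdims=True)
P_y = rng.random((n_y, n_y)); P_y /= P_y.sum(axis=1, keepdims=True)
f = rng.random((n_x, n_y))
same = np.allclose(expect_loops(P_x, P_y, f), expect_broadcast(P_x, P_y, f))
```

Dropping the `np.transpose` still reshapes without error (the total size matches), but the comparison would then generally fail, the same kind of mismatch discussed for z_Q above.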
It's failing on the gcy side because this function returns False:
https://github.com/jstac/sdfs_via_autodiff/blob/main/code/gcy/discrete/gcy_wc_ratio.py#L306
All help is appreciated.