jstac / sdfs_via_autodiff

Computing wealth consumption ratios and stochastic discount factors under smooth recursive utility

Continuous time results are not robust #6

Open JunnanZ opened 1 year ago

JunnanZ commented 1 year ago

The computed WC ratios are affected by a few factors, including:

I will come back to this issue once I clean up the code.

JunnanZ commented 1 year ago

Hi @Smit-create, maybe you can help me with debugging this issue.

Here is the code for testing:

# First import the relevant functions in solvers.py, ssy_model.py and ssy_wc_ratio_continuous.py
import jax
import jax.numpy as jnp
import numpy as np

ssy = SSY()
wc_loglinear = wc_loglinear_factory(ssy)

# You can decrease the state size and ram_free, or use another algorithm, so that it runs on your computer
std_devs = 3.0
grids, w_star = wc_ratio_continuous(ssy, h_λ_grid_size=10,
                                    h_c_grid_size=10, h_z_grid_size=10,
                                    z_grid_size=10, num_std_devs=std_devs,
                                    d=5, seed=1234, w_init=None, ram_free=20,
                                    tol=1e-5, method='quadrature',
                                    algorithm="newton", verbose=True,
                                    write_to_file=False)

# Simulate a state path
seed = 1234
T = 1000000
key = jax.random.PRNGKey(seed)
mc_draws = jax.random.normal(key, shape=(4, T))
ssy_params = jnp.array(ssy.params)
x_seq = next_state(ssy_params, jnp.zeros(4), mc_draws)

# Compute the mean and std of wealth-consumption ratios on the path
wc_seq = lin_interp(x_seq, w_star, grids)
print(jnp.array([wc_seq.mean(), wc_seq.std()]))

# Compute the mean and std of log-linear WC ratios on the path
x_seq_np = np.asarray(x_seq)
wc_loglin_seq = np.asarray([np.exp(wc_loglinear(x)) + 1 for x in x_seq_np.T])
print(wc_loglin_seq.mean(), wc_loglin_seq.std())

You will find that as we change std_devs, the simulated moments of our computed WC ratios change as well. This suggests that there is something wrong with our continuous time results, because the simulated moments should not change much when we use a larger/smaller state space to compute and store the WC ratios.

At first I thought this was because the grid points we use in each dimension are too sparse to give accurate results. After some experiments, I found that neither the number of grid points nor the number of quadrature points affects the results as much as std_devs does.

Maybe you can do some experiments as well based on my code above and try to find the underlying issue.
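One way to separate the effect of the grid width from the solver itself is a one-dimensional toy check: tabulate a function with a known closed form on a grid spanning plus or minus k stationary standard deviations, then evaluate it along a simulated AR(1) path by linear interpolation. The sketch below is only an illustration under assumed values. The parameters `rho` and `sigma`, the choice of `np.exp` as the stored function, and the use of `np.interp` in place of `lin_interp` are all placeholders, not the repository's code. Since `np.interp` clamps points outside the grid to the boundary value, a narrow grid distorts the simulated moments in this toy setting, while sufficiently wide grids agree with each other.

```python
import numpy as np

rho, sigma = 0.9, 0.1          # assumed AR(1) parameters, for illustration only
T = 50_000

rng = np.random.default_rng(1234)
shocks = rng.standard_normal(T)

# Simulate z_{t+1} = rho * z_t + sigma * eta_{t+1}, starting from z_0 = 0
z = np.empty(T)
z_prev = 0.0
for t in range(T):
    z_prev = rho * z_prev + sigma * shocks[t]
    z[t] = z_prev

stat_std = sigma / np.sqrt(1 - rho**2)   # stationary std of z

def simulated_moments(num_std_devs, grid_size=200):
    """Mean and std of exp(z) along the path, with exp tabulated on a grid
    covering +/- num_std_devs stationary standard deviations."""
    half_width = num_std_devs * stat_std
    grid = np.linspace(-half_width, half_width, grid_size)
    vals = np.exp(grid)                    # stand-in for a stored solution such as w_star
    path_vals = np.interp(z, grid, vals)   # clamps to boundary values outside the grid
    return path_vals.mean(), path_vals.std()

moments = {k: simulated_moments(k) for k in (1.0, 2.0, 3.0, 5.0, 7.0)}
for k, (m, s) in moments.items():
    print(f"num_std_devs = {k}: mean = {m:.6f}, std = {s:.6f}")
```

If the moments in a check like this settle down as the grid widens but the moments from the full model do not, that would point away from simple out-of-grid clamping as the explanation.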

Smit-create commented 1 year ago

I'm not exactly sure if I made the correct change after looking at the paper, but using the following diff for the Kg_vmap_quad case:

diff --git a/code/ssy/continuous_junnan/ssy_wc_ratio_continuous.py b/code/ssy/continuous_junnan/ssy_wc_ratio_continuous.py
index d40af76..4cc2d86 100644
--- a/code/ssy/continuous_junnan/ssy_wc_ratio_continuous.py
+++ b/code/ssy/continuous_junnan/ssy_wc_ratio_continuous.py
@@ -143,7 +143,7 @@ def Kg_vmap_quad(x, ssy_params, w_vals, grids, nodes, weights):
     pf = jnp.exp(next_x[0] * θ)

     # Interpolate g(next_x) given w_vals:
-    next_g = lin_interp(next_x, w_vals, grids)**θ
+    next_g = lin_interp(next_x, w_vals, grids)

     e_x = jnp.dot(next_g*pf, weights)
     Kg = const * e_x

I got these results:

1. std_devs = 5.0

% python a.py
OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.
batch_size = 5000
Beginning iteration

iter = 0, error = 0.9502661443862497
iter = 1, error = 0.018092205941166783
iter = 2, error = 2.665009771529725e-06
iter = 3, error = 0.0
Iteration converged after 4 iterations
[1.95867803e+00 3.61351502e-04]
882.0760183132355 9.01512290707902

2. std_devs = 3.0

% python a.py
OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.
batch_size = 5000
Beginning iteration

iter = 0, error = 0.945392588063624
iter = 1, error = 0.01784688962278147
iter = 2, error = 3.818055049231717e-06
iter = 3, error = 0.0
Iteration converged after 4 iterations
[1.95868279e+00 3.61349493e-04]
882.0760183132355 9.01512290707902

3. std_devs = 2.0

% python a.py
OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.
batch_size = 5000
Beginning iteration

iter = 0, error = 0.9437685616318792
iter = 1, error = 0.017765545536499117
iter = 2, error = 3.170295667764833e-06
iter = 3, error = 0.0
Iteration converged after 4 iterations
[1.95868414e+00 3.61349586e-04]
882.0760183132355 9.01512290707902

4. std_devs = 7.0

% python a.py
OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.
batch_size = 5000
Beginning iteration

iter = 0, error = 0.9615155386476415
iter = 1, error = 0.018660016556131254
iter = 2, error = 2.880348973999247e-06
iter = 3, error = 0.0
Iteration converged after 4 iterations
[1.95867106e+00 3.61355681e-04]
882.0760183132355 9.01512290707902

JunnanZ commented 1 year ago

Hi @Smit-create, the current version in master interpolates the function $w$ instead of $w^\theta$, as can be seen in https://github.com/jstac/sdfs_via_autodiff/blob/main/code/ssy/continuous_junnan/ssy_wc_ratio_continuous.py#L193-L196, so next_g = lin_interp(next_x, w_vals, grids)**θ should be correct.

You can also use the results from the loglinear approximation (882.0760183132355 9.01512290707902) as a benchmark. The simulated moments from our more accurate WC ratios shouldn't be too different from the loglinear ones.
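To see why it matters whether the stored values are $w$ or $w^\theta$, note that linear interpolation and raising to a power do not commute between grid nodes. The following is a minimal one-dimensional sketch with made-up numbers: `theta`, `grid` and `w_vals` are hypothetical, and `np.interp` stands in for `lin_interp`.

```python
import numpy as np

theta = -2.0                      # assumed value, for illustration only
grid = np.array([0.0, 1.0])
w_vals = np.array([1.0, 2.0])     # hypothetical values of w at the grid nodes
x = 0.5                           # query point between the two nodes

# Interpolate w, then raise to theta (appropriate when w_vals stores w)
interp_then_power = np.interp(x, grid, w_vals) ** theta

# Raise to theta at the nodes, then interpolate (appropriate when storing w**theta)
power_then_interp = np.interp(x, grid, w_vals ** theta)

print(interp_then_power, power_then_interp)

# At a grid node the two orderings coincide
at_node_a = np.interp(0.0, grid, w_vals) ** theta
at_node_b = np.interp(0.0, grid, w_vals ** theta)
print(at_node_a, at_node_b)
```

The two orderings agree at the nodes and differ in between, so dropping the `**θ` changes the operator being iterated rather than just its numerical accuracy.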

Smit-create commented 1 year ago

There is also a small typo in the code relative to the text:

diff --git a/code/ssy/continuous_junnan/ssy_wc_ratio_continuous.py b/code/ssy/continuous_junnan/ssy_wc_ratio_continuous.py
index d40af76..4786fe6 100644
--- a/code/ssy/continuous_junnan/ssy_wc_ratio_continuous.py
+++ b/code/ssy/continuous_junnan/ssy_wc_ratio_continuous.py
@@ -78,7 +78,7 @@ def next_state(ssy_params, x, η_array):
     h_λ = ρ_λ * h_λ + s_λ * η_array[0]
     h_c = ρ_c * h_c + s_c * η_array[1]
     h_z = ρ_z * h_z + s_z * η_array[2]
-    z = ρ * z + σ_z * η_array[3]
+    z = ρ * z + (1 - ρ**2)**0.5 * σ_z * η_array[3]

     return jnp.array([h_λ, h_c, h_z, z])

It still doesn't make much of a difference.
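For reference, the two versions of the z update imply different stationary standard deviations. For an update z' = ρ z + s η with standard normal η and |ρ| < 1, the stationary variance is s² / (1 - ρ²), so s = σ_z gives a stationary std of σ_z / sqrt(1 - ρ²), while s = sqrt(1 - ρ²) σ_z gives exactly σ_z. The sketch below checks this by iterating the variance recursion; `rho` and `sigma_z` are assumed values for illustration, not the SSY calibration.

```python
import math

rho, sigma_z = 0.9, 0.5   # assumed values, for illustration only

def stationary_std(rho, s, n_iter=2000):
    """Iterate the variance recursion v' = rho**2 * v + s**2 implied by
    z' = rho * z + s * eta with eta ~ N(0, 1), starting from v = 0."""
    v = 0.0
    for _ in range(n_iter):
        v = rho**2 * v + s**2
    return math.sqrt(v)

# Original update: shock scaled by sigma_z
std_original = stationary_std(rho, sigma_z)

# Modified update: shock scaled by sqrt(1 - rho^2) * sigma_z
std_rescaled = stationary_std(rho, math.sqrt(1 - rho**2) * sigma_z)

print(std_original, sigma_z / math.sqrt(1 - rho**2))
print(std_rescaled, sigma_z)
```

Which of the two conventions the grid construction for z assumes is not shown in this thread, so it may be worth checking that the grid and the simulation use the same one.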