jstac / sdfs_via_autodiff

Computing wealth consumption ratios and stochastic discount factors under smooth recursive utility

minor code fixes #7

Closed. Smit-create closed this pull request 1 year ago.

Smit-create commented 1 year ago

@JunnanZ @jstac I am trying to make some changes in the continuous ssy case. Is there a way to test my changes, something like what I did for testing in the discrete gcy case?

JunnanZ commented 1 year ago

Hi @Smit-create, the most direct way is to run the first few cells of https://github.com/jstac/sdfs_via_autodiff/blob/main/code/ssy/continuous_junnan/ssy_test_continuous.md up to "Tune AA parameters". (You might need to install jupytext to open the markdown file as a Jupyter notebook.) This runs the whole program that computes the wealth-consumption ratio under different algorithms, so it might not be suitable for testing purposes.

What I usually do is run this function https://github.com/jstac/sdfs_via_autodiff/blob/main/code/ssy/continuous_junnan/ssy_wc_ratio_continuous.py#L152 to get a compiled version of the operator $T$. Then I either compare different versions of this operator or feed it into one of the solvers in https://github.com/jstac/sdfs_via_autodiff/blob/main/code/solvers.py and see how fast it converges.
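A minimal sketch of the "feed it to a solver and see how fast it converges" check, assuming only that the compiled operator maps an array `w` to an array of the same shape. The solver `successive_approx` and the operator `T_toy` are hypothetical stand-ins, not the functions in `solvers.py`; plain successive approximation is used here just to show how to read off an iteration count and a timing.

```python
import time

import jax
import jax.numpy as jnp


def successive_approx(T, w_init, tol=1e-5, max_iter=10_000):
    """Iterate w <- T(w) until the sup-norm change falls below tol.

    Returns the approximate fixed point and the number of iterations used.
    """
    w = w_init
    for i in range(1, max_iter + 1):
        w_new = T(w)
        error = jnp.max(jnp.abs(w_new - w))
        w = w_new
        if error < tol:
            return w, i
    raise RuntimeError("successive approximation did not converge")


# Hypothetical toy contraction (modulus 0.9) standing in for a compiled T;
# its unique fixed point is w* = 10 at every grid point.
@jax.jit
def T_toy(w):
    return 0.9 * w + 1.0


w_init = jnp.zeros((3, 4, 5, 6))
t0 = time.perf_counter()
w_star, n_iter = successive_approx(T_toy, w_init)
w_star.block_until_ready()
elapsed = time.perf_counter() - t0
print(f"converged in {n_iter} iterations, {1000 * elapsed:.1f} ms")
```

Swapping `T_toy` for two versions of the real operator and comparing `n_iter` and the elapsed time gives a quick convergence check on a small grid.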

I would appreciate it if you could also pay some attention to https://github.com/jstac/sdfs_via_autodiff/issues/6 while you experiment with the code. I still haven't pinpointed the reason behind it.

Smit-create commented 1 year ago

Thanks @JunnanZ. Hmm, I tried running the notebook, but it takes a long time on my local machine (MacBook Air M1). Maybe it doesn't have a good enough GPU for this.

JunnanZ commented 1 year ago

@Smit-create Yes, that's why I don't think it's suitable for testing purposes. It helps to choose a smaller grid (smaller zs, hzs, hcs, and hλs). Also, you can focus on one method for testing, for example, the one that uses Newton's method. It's usually pretty fast for smaller grids.
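As an illustration of why small grids keep a Newton-based method fast, here is a generic sketch (not the repo's solver) that applies Newton's method to $F(w) = w - T(w)$ with the Jacobian computed by JAX autodiff. Because the state array is flattened, the Jacobian in this sketch is a dense $n \times n$ matrix with $n$ = `zs * hzs * hcs * hλs`, so memory and solve cost grow quickly with the grid size. `newton_fixed_point` and `T_toy` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def newton_fixed_point(T, w_init, tol=1e-5, max_iter=50):
    """Solve w = T(w) by Newton's method on F(w) = w - T(w).

    The state array is flattened, so the Jacobian of F is a dense (n, n)
    matrix with n = w_init.size; this is only cheap for small grids.
    """
    shape = w_init.shape

    def F(x):
        return x - T(x.reshape(shape)).ravel()

    jac_F = jax.jacfwd(F)  # forward-mode autodiff Jacobian of F
    x = w_init.ravel()
    for i in range(1, max_iter + 1):
        step = jnp.linalg.solve(jac_F(x), F(x))
        x = x - step
        if jnp.max(jnp.abs(step)) < tol:
            return x.reshape(shape), i
    raise RuntimeError("Newton's method did not converge")


# Hypothetical nonlinear toy operator on a tiny 4-D grid (2*2*2*2 = 16 states)
def T_toy(w):
    return 0.5 * jnp.tanh(w) + 1.0


w_star, n_iter = newton_fixed_point(T_toy, jnp.zeros((2, 2, 2, 2)))
print(n_iter, jnp.max(jnp.abs(w_star - T_toy(w_star))))
```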

Smit-create commented 1 year ago

I see, thank you. That works.

jstac commented 1 year ago

It would be good to write some clear tests, so that it's not necessary to ask how to test the code.

Smit-create commented 1 year ago

Yes. Also, is there a test that checks the output of the loop version against the vectorised code? It would be nice to have one for the continuous case too, like the ones we have for the discrete gcy/ssy cases.
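A toy sketch of the kind of check asked about here: evaluate the same per-state computation with an explicit Python loop and with a vectorised version built with `jax.vmap`, then compare the outputs. The function `g` and the grid `xs` are hypothetical placeholders for the actual operator code.

```python
import jax
import jax.numpy as jnp


def g(x):
    """Hypothetical per-state computation (stands in for one grid point's update)."""
    return jnp.exp(-x**2) + jnp.sin(x)


xs = jnp.linspace(-2.0, 2.0, 50)

# Loop version: evaluate g one state at a time
loop_out = jnp.array([g(x) for x in xs])

# Vectorised version: map g over the leading axis in a single call
vec_out = jax.vmap(g)(xs)

print("Same results?", bool(jnp.allclose(loop_out, vec_out)))
```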

JunnanZ commented 1 year ago

Hi @Smit-create, I wrote the vectorized version just for testing. It didn't show any speed or memory use improvement so I ditched it. You can ignore that file for now unless you have an idea to improve it.

As for testing in general, I usually find it enough to just use the existing function as a benchmark. For example, you can do something like this:

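```python
import jax
import jax.numpy as jnp

# T_fun_factory, params and w_init come from the existing ssy code;
# T_fun_factory_new is the new function you are testing
T_old = T_fun_factory(params)
T_new = T_fun_factory_new(params)

# Random test input with the same shape as the initial guess
seed = 1234
key = jax.random.PRNGKey(seed)
w0 = jax.random.uniform(key, shape=w_init.shape)

w1_old = T_old(w0)
w1_new = T_new(w0)

jnp.allclose(w1_old, w1_new)  # True if the outputs agree within tolerance
```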
jstac commented 1 year ago

I think it's worth writing tests as functions because it makes it easy for other team members to check that a modification of an algorithm doesn't break it, or see whether a change leads to a speed gain.
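A minimal sketch of how the old-vs-new comparison could be packaged as test functions in pytest style, assuming each operator maps an array to an array of the same shape. `assert_operators_agree`, `T_reference`, and `T_candidate` are hypothetical names; in the repo, the two operators would come from the old and new factories.

```python
import jax
import jax.numpy as jnp


def assert_operators_agree(T_old, T_new, shape, seed=1234, n_draws=5,
                           rtol=1e-5, atol=1e-6):
    """Check that two operators give numerically equal outputs on random inputs."""
    key = jax.random.PRNGKey(seed)
    for subkey in jax.random.split(key, n_draws):
        w0 = jax.random.uniform(subkey, shape=shape)
        w1_old, w1_new = T_old(w0), T_new(w0)
        assert w1_old.shape == w1_new.shape
        assert jnp.allclose(w1_old, w1_new, rtol=rtol, atol=atol), "operators disagree"


# Hypothetical reference operator and a rewritten version that should match it
def T_reference(w):
    return 0.9 * w + jnp.log1p(w)


@jax.jit
def T_candidate(w):
    return jnp.log1p(w) + 0.9 * w


def test_T_candidate_matches_reference():
    # pytest collects functions named test_*; the grid is deliberately tiny
    assert_operators_agree(T_reference, T_candidate, shape=(3, 4, 5, 6))
```

Running `pytest` on a file containing these functions would then flag any change to the operator that alters its output.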

JunnanZ commented 1 year ago

Ok, that makes sense. Do you think a function along the lines of the code above (comparing new and old functions) is enough? I'm not sure what's the best way to write tests in this scenario.

jstac commented 1 year ago

> Ok, that makes sense. Do you think a function along the lines of the code above (comparing new and old functions) is enough? I'm not sure what's the best way to write tests in this scenario.

Yep, I think that's a good start. We can build from there.

JunnanZ commented 1 year ago

@Smit-create Here is a test function I just wrote. Feel free to make adjustments and incorporate into your code.

```python
import time

import jax
import jax.numpy as jnp
from quantecon.quad import qnwnorm

# SSY and build_grid come from the repo's ssy code


def compare_T_factories(T_fact_old, T_fact_new, seed=1234):
    """Compare the results and speed of two function factories for T"""
    ssy = SSY()
    zs, hzs, hcs, hλs = 3, 4, 5, 6  # small grid keeps the test fast
    std_devs = 3.0

    ssy_params = jnp.array(ssy.params)
    grids = build_grid(ssy, hλs, hcs, hzs, zs, std_devs)

    # Quadrature nodes and weights for a 4-dimensional standard normal
    d = 4
    nodes, weights = qnwnorm([d, d, d, d])
    nodes = jnp.asarray(nodes.T)
    weights = jnp.asarray(weights)

    state_size = hλs * hcs * hzs * zs
    batch_size = state_size

    params_quad = ssy_params, grids, nodes, weights

    T_old = T_fact_old(params_quad, 'quadrature', batch_size)
    T_new = T_fact_new(params_quad, 'quadrature', batch_size)

    # Run them once to compile
    w0 = jnp.zeros((zs, hzs, hcs, hλs))
    T_old(w0).block_until_ready()
    T_new(w0).block_until_ready()

    key = jax.random.PRNGKey(seed)
    w0 = jax.random.uniform(key, shape=(zs, hzs, hcs, hλs))

    # JAX dispatches asynchronously, so block until each result is ready
    # before stopping the timer
    t0 = time.perf_counter()
    w1_old = T_old(w0).block_until_ready()
    t_old = 1000 * (time.perf_counter() - t0)

    t0 = time.perf_counter()
    w1_new = T_new(w0).block_until_ready()
    t_new = 1000 * (time.perf_counter() - t0)

    print("Speed comparison: {:.4f}ms vs {:.4f}ms".format(t_old, t_new))
    print("Same results? {}".format(jnp.allclose(w1_old, w1_new)))
```
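One caveat about timing: JAX dispatches work asynchronously, so a timer wrapped around `T(w0)` can stop before the computation has finished unless the result is blocked on with `block_until_ready()`. A sketch of a small helper, `time_jax_call` (a hypothetical name), that excludes the compile call, blocks on each result, and reports the median of several runs to reduce noise:

```python
import statistics
import time

import jax
import jax.numpy as jnp


def time_jax_call(f, *args, n_reps=10):
    """Median wall-clock time of f(*args) in milliseconds.

    The first call triggers compilation and is excluded. block_until_ready()
    waits for JAX's asynchronous dispatch to finish, so the timer measures
    the computation rather than just the launch.
    """
    f(*args).block_until_ready()  # warm-up / compile
    times = []
    for _ in range(n_reps):
        t0 = time.perf_counter()
        f(*args).block_until_ready()
        times.append(1000 * (time.perf_counter() - t0))
    return statistics.median(times)


# Example with a hypothetical jitted operator on a 3 x 4 x 5 x 6 grid
T_toy = jax.jit(lambda w: jnp.exp(-w) @ jnp.ones((6, 6)))
w0 = jax.random.uniform(jax.random.PRNGKey(1234), shape=(3, 4, 5, 6))
print("median time: {:.4f}ms".format(time_jax_call(T_toy, w0)))
```

Taking the median over repeats makes a single slow run (for example, one delayed by other processes) less likely to distort the old-vs-new comparison.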
jstac commented 1 year ago

I suggest we get this merged and move on. @Smit-create, could you please add the test above, plus any modifications you think we need?

Smit-create commented 1 year ago

I haven't made any changes to T_fun_factory, so it shouldn't affect the correctness of the results. Please review and merge this if this looks good. Thanks!

Smit-create commented 1 year ago

@jstac @JunnanZ Please review this commit: https://github.com/jstac/sdfs_via_autodiff/pull/7/commits/d020c6ce95c79c2e3031ffd4334b341f2d12064d . I have tried to address https://github.com/jstac/sdfs_via_autodiff/issues/10#issuecomment-1596418269

JunnanZ commented 1 year ago

Hi @Smit-create, I fixed a small error in your latest commit.

Before merging, I found that b284491f57974704a8a5aeee0cb164ae34811fe2 is not quite right. First of all, np.sqrt(state_size) is not necessarily an integer, so it will throw an exception. Even if we make that an integer, what you are doing is not the same as my original code. If you agree, please revert that commit and then I will merge.
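To make the exception concrete: `range` only accepts integers, while `np.sqrt` returns a NumPy float even when the input is a perfect square. The sketch below uses the state size of the small test grid in `compare_T_factories` ($3 \cdot 4 \cdot 5 \cdot 6 = 360$) purely for illustration and shows two ways to get an integer bound; `math.isqrt` (Python 3.8+) avoids floating point altogether.

```python
import math

import numpy as np

state_size = 3 * 4 * 5 * 6  # 360, not a perfect square

# range() needs an integer bound; np.sqrt returns a NumPy float, so this fails
try:
    range(1, np.sqrt(state_size) + 1)
except TypeError as err:
    print("TypeError:", err)

# Truncating the float to an int works for moderate sizes like these
upper = int(np.sqrt(state_size)) + 1

# math.isqrt computes the integer square root without going through floats
upper_exact = math.isqrt(state_size) + 1

print(upper, upper_exact)
```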

Smit-create commented 1 year ago

> First of all, np.sqrt(state_size) is not necessarily an integer, so it will throw an exception.

Thanks @JunnanZ. I fixed that issue. Also, the algo I used to find the largest divisor of state_size that does not exceed batch_size is usually faster than the naive approach. I also benchmarked it and checked that it gives the same results as the old code:

```python
import numpy as np
import time


def fac_new(state_size, batch_size):
    """Largest divisor of state_size that is <= batch_size."""
    max_div = 1
    # Divisors come in pairs (i, state_size // i) with i <= sqrt(state_size)
    for i in range(1, int(np.sqrt(state_size)) + 1):
        if state_size % i == 0:
            if i <= batch_size:
                max_div = max(max_div, i)
            z = state_size // i
            if z <= batch_size:
                max_div = max(max_div, z)
    batch_size = max_div
    return batch_size


def fac_old(state_size, batch_size):
    """Same result, by counting batch_size down until it divides state_size."""
    while (state_size % batch_size > 0):
        batch_size -= 1
    return batch_size


new_times = []
old_times = []

for i in range(1000):
    state_size = int(np.random.uniform(10, 1000000))
    batch_size = int(np.random.uniform(1, state_size))
    st_time = time.time()
    r1 = fac_old(state_size, batch_size)
    e_time = time.time()
    old_times.append(e_time - st_time)
    st_time = time.time()
    r2 = fac_new(state_size, batch_size)
    e_time = time.time()
    new_times.append(e_time - st_time)
    # Both functions must agree on every random draw
    assert r1 == r2, f"Failed: {state_size} {batch_size} {r1} {r2}"

print("Old algo time mean:", np.mean(old_times))
print("New algo time mean:", np.mean(new_times))
print("New algo is faster by:", np.mean(old_times)/np.mean(new_times), "times")
```

This gives me:

```
% python a.py
Old algo time mean: 0.004881689786911011
New algo time mean: 2.1903276443481445e-05
New algo is faster by: 222.87486529732556 times
```
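A rough count of the work each function does helps explain the size of the gap. With $n$ = `state_size`, $b$ = `batch_size`, and $d^{*}$ the largest divisor of $n$ that does not exceed $b$, `fac_old` performs $N_{\text{old}}$ modulo checks, while `fac_new` runs $N_{\text{new}}$ loop iterations:

```math
N_{\text{old}} = b - d^{*} + 1 \le b,
\qquad
N_{\text{new}} = \left\lfloor \sqrt{n} \right\rfloor
```

For $n < 10^{6}$, as in the benchmark, `fac_new` never needs more than 1000 iterations, whereas `fac_old` needs close to $b$ checks whenever $n$ has no divisor just below $b$ (for example, when $n$ is prime, so that $d^{*} = 1$). The measured speed-up therefore depends on the random draws.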
JunnanZ commented 1 year ago

Hi @Smit-create, thanks for the improved algo. I wrote that one-liner to quickly get the batch size and didn't think too much about its performance, since it's only run once in the whole program and only takes a few milliseconds. But your code is indeed faster.

Why the force-push though? Did I do something wrong? Sorry, I'm not very familiar with collaborating on GitHub. I simply pushed a small commit adf9a29 to this branch after your commit https://github.com/jstac/sdfs_via_autodiff/commit/d020c6ce95c79c2e3031ffd4334b341f2d12064d.

Smit-create commented 1 year ago

> Why the force-push though?

Since you updated the main branch by deleting the vectorized file, this PR had a conflict (which means we can't merge it until the conflict is resolved). So I rebased it on top of the latest main and force-pushed, but I mistakenly forgot to pick your commit https://github.com/jstac/sdfs_via_autodiff/commit/adf9a293a9817101331e8568b5b51d9b52a58f21. So I added it back manually here: https://github.com/jstac/sdfs_via_autodiff/pull/7/commits/c714a5f68dffd8057d1bdadd16c709446c877ddc.

> Did I do something wrong? Sorry I'm not very familiar with collaborating on Github.

You did it perfectly. I have added your commit back.

JunnanZ commented 1 year ago

Hi @Smit-create, thanks for the explanation.

> Since the main branch was updated by you by deleting the vectorized file, this PR was having a conflict (it means we can't merge this until we resolve that). So I rebased it on top of the latest main via a force push...

I'm not an expert on Git, but on a public branch like this, wouldn't it be better to use git merge in such a scenario?

Anyway, I just tested your changes again and they are good. I'm merging.

Smit-create commented 1 year ago

> it might be a better idea to use git merge in such a scenario?

Well, it depends. Sometimes a merge commit is better because it makes reviewing easier, while rebasing and force-pushing keeps the Git history clean by avoiding unnecessary merge commits, which can be a bit painful when reverting a PR.