@JunnanZ @jstac I am trying to make some changes in the continuous ssy case. Is there a way to test my changes, something like what I did in the discrete gcy case?
Hi @Smit-create, the direct way is to run the first few blocks of https://github.com/jstac/sdfs_via_autodiff/blob/main/code/ssy/continuous_junnan/ssy_test_continuous.md, up until "Tune AA parameters". (You might need to install jupytext to run the markdown file as a Jupyter notebook.) This runs the whole program, which computes the wealth-consumption ratio under different algorithms, so it might not be suitable for testing purposes.
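For reference, the conversion can be done through jupytext's documented Python API (the CLI equivalent is `jupytext --to notebook ssy_test_continuous.md`):

```python
import jupytext

# Read the markdown file and write it back out as a Jupyter notebook
nb = jupytext.read("ssy_test_continuous.md")
jupytext.write(nb, "ssy_test_continuous.ipynb")
```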
What I usually do is run the function at https://github.com/jstac/sdfs_via_autodiff/blob/main/code/ssy/continuous_junnan/ssy_wc_ratio_continuous.py#L152 to get a compiled version of the operator $T$. Then I either compare different versions of this operator, or I feed it into one of the solvers in https://github.com/jstac/sdfs_via_autodiff/blob/main/code/solvers.py and see how fast it converges.
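A minimal sketch of that workflow, assuming the factory at that line is `T_fun_factory` (the solver name below is a placeholder; check `solvers.py` for the actual entry points, and note that `params` and `w0` come from the model setup):

```python
# Sketch only: `fixed_point_solver` is a placeholder, not the repo's real API.
from ssy_wc_ratio_continuous import T_fun_factory
from solvers import fixed_point_solver

T = T_fun_factory(params)           # compiled version of the operator T
w_star = fixed_point_solver(T, w0)  # iterate w_{k+1} = T(w_k) to convergence
```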
I would appreciate it if you could also pay some attention to https://github.com/jstac/sdfs_via_autodiff/issues/6 while you experiment with the code. I still haven't pinpointed the cause of that one.
Thanks @JunnanZ. Hmm, I tried running the notebook, but it takes a long time on my local machine (a Mac M1 Air). Maybe it doesn't have a good enough GPU for this.
@Smit-create Yes, that's why I don't think it's suitable for testing purposes. It helps to choose a smaller grid (smaller `zs`, `hzs`, `hcs`, and `hλs`). Also, you can focus on one method for testing, for example the one that uses Newton's method. It's usually pretty fast for smaller grids.
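For instance, test-sized grid values along those lines (purely illustrative, not tuned):

```python
# Small grid for quick experiments -- values are illustrative
zs, hzs, hcs, hλs = 3, 4, 5, 6
```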
I see, thank you. That works.
It would be good to write some tests and document them clearly, so that it isn't necessary to ask how to test the code.
Yes. Also, is there a test that checks the loop-based output against the vectorised code? It would be nice to have that for the continuous case too, like we have in the discrete gcy/ssy case.
Hi @Smit-create, I wrote the vectorized version just for testing. It didn't show any speed or memory improvement, so I ditched it. You can ignore that file for now, unless you have an idea for improving it.
As for testing in general, I usually find it enough to just use the existing function as a benchmark. For example, you can do something like this:
```python
import jax
import jax.numpy as jnp

T_old = T_fun_factory(params)
T_new = T_fun_factory_new(params)  # the new function you are testing

# Apply both operators to the same random initial guess
seed = 1234
key = jax.random.PRNGKey(seed)
w0 = jax.random.uniform(key, shape=w_init.shape)
w1_old = T_old(w0)
w1_new = T_new(w0)
jnp.allclose(w1_old, w1_new)  # True if the two outputs agree elementwise
```
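One caveat with the comparison above: `jnp.allclose` uses NumPy's default tolerances (`rtol=1e-05`, `atol=1e-08`), and JAX computes in 32-bit floats unless x64 is enabled, so a looser tolerance may be appropriate, e.g.:

```python
# Looser tolerances for float32 comparisons (illustrative values)
jnp.allclose(w1_old, w1_new, rtol=1e-4, atol=1e-6)
```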
I think it's worth writing tests as functions, because that makes it easy for other team members to check that a modification of an algorithm doesn't break it, or to see whether a change leads to a speed gain.
Ok, that makes sense. Do you think a function along the lines of the code above (comparing new and old functions) is enough? I'm not sure what's the best way to write tests in this scenario.
Yep, I think that's a good start. We can build from there.
@Smit-create Here is a test function I just wrote. Feel free to make adjustments and incorporate into your code.
```python
import time

import jax
import jax.numpy as jnp

# SSY, build_grid and qnwnorm come from the repo's own modules.

def compare_T_factories(T_fact_old, T_fact_new, seed=1234):
    """Compare the results and speed of two function factories for T."""
    ssy = SSY()
    zs, hzs, hcs, hλs = 3, 4, 5, 6  # small test grid
    std_devs = 3.0
    ssy_params = jnp.array(ssy.params)
    grids = build_grid(ssy, hλs, hcs, hzs, zs, std_devs)
    d = 4
    nodes, weights = qnwnorm([d, d, d, d])
    nodes = jnp.asarray(nodes.T)
    weights = jnp.asarray(weights)
    state_size = hλs * hcs * hzs * zs
    batch_size = state_size
    params_quad = ssy_params, grids, nodes, weights
    T_old = T_fact_old(params_quad, 'quadrature', batch_size)
    T_new = T_fact_new(params_quad, 'quadrature', batch_size)
    # Run them once to compile
    w0 = jnp.zeros((zs, hzs, hcs, hλs))
    T_old(w0)
    T_new(w0)
    key = jax.random.PRNGKey(seed)
    w0 = jax.random.uniform(key, shape=(zs, hzs, hcs, hλs))
    # block_until_ready so the timings measure actual compute
    # (JAX dispatches asynchronously)
    t0 = time.time()
    w1_old = T_old(w0).block_until_ready()
    t1 = time.time()
    t_old = 1000 * (t1 - t0)
    t0 = time.time()
    w1_new = T_new(w0).block_until_ready()
    t1 = time.time()
    t_new = 1000 * (t1 - t0)
    print("Speed comparison: {:.4f}ms vs {:.4f}ms".format(t_old, t_new))
    print("Same results? {}".format(jnp.allclose(w1_old, w1_new)))
```
I suggest we get this merged and move on. @Smit-create, could you please add the test above, plus any modifications you think we need?
I haven't made any changes to `T_fun_factory`, so it shouldn't affect the correctness of the results. Please review and merge if this looks good. Thanks!
@jstac @JunnanZ Please review this commit: https://github.com/jstac/sdfs_via_autodiff/pull/7/commits/d020c6ce95c79c2e3031ffd4334b341f2d12064d . I have tried to address https://github.com/jstac/sdfs_via_autodiff/issues/10#issuecomment-1596418269
Hi @Smit-create, I fixed a small error in your latest commit.
Before merging, I found that b284491f57974704a8a5aeee0cb164ae34811fe2 is not quite right. First of all, `np.sqrt(state_size)` is not necessarily an integer, so it will throw an exception. Even if we make it an integer, what you are doing is not the same as my original code. If you agree, please revert that commit and then I will merge.
> First of all, `np.sqrt(state_size)` is not necessarily an integer, so it will throw an exception.
Thanks @JunnanZ. I fixed that issue. Also, the algorithm I used for finding the `max_divisor` of `state_size` is usually faster than the naive approach: it enumerates divisor pairs up to `sqrt(state_size)`, so it does O(√n) work instead of up to O(n). I also did some benchmarking and testing of the new results:
```python
import time

import numpy as np


def fac_new(state_size, batch_size):
    # Largest divisor of state_size that is <= batch_size, found by
    # checking divisor pairs (i, state_size // i) up to sqrt(state_size).
    max_div = 1
    for i in range(1, int(np.sqrt(state_size)) + 1):
        if state_size % i == 0:
            if i <= batch_size:
                max_div = max(max_div, i)
            z = state_size // i
            if z <= batch_size:
                max_div = max(max_div, z)
    return max_div


def fac_old(state_size, batch_size):
    # Naive approach: decrement batch_size until it divides state_size.
    while state_size % batch_size > 0:
        batch_size -= 1
    return batch_size


new_times = []
old_times = []
for i in range(1000):
    state_size = int(np.random.uniform(10, 1000000))
    batch_size = int(np.random.uniform(1, state_size))
    st_time = time.time()
    r1 = fac_old(state_size, batch_size)
    e_time = time.time()
    old_times.append(e_time - st_time)
    st_time = time.time()
    r2 = fac_new(state_size, batch_size)
    e_time = time.time()
    new_times.append(e_time - st_time)
    assert r1 == r2, f"Failed: {state_size} {batch_size} {r1} {r2}"

print("Old algo time mean:", np.mean(old_times))
print("New algo time mean:", np.mean(new_times))
print("New algo is faster by:", np.mean(old_times) / np.mean(new_times), "times")
```
This gives me:

```
% python a.py
Old algo time mean: 0.004881689786911011
New algo time mean: 2.1903276443481445e-05
New algo is faster by: 222.87486529732556 times
```
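As an aside, `timeit` usually gives more stable numbers than hand-rolled `time.time()` deltas for micro-benchmarks of this size, e.g.:

```python
import timeit

# Stable micro-benchmark of the two divisor searches (inputs illustrative)
print(timeit.timeit(lambda: fac_old(362880, 1000), number=1000))
print(timeit.timeit(lambda: fac_new(362880, 1000), number=1000))
```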
Hi @Smit-create, thanks for the improved algo. I wrote that one-liner to quickly give me the batch size and didn't think much about its performance, since it's only run once in the whole program and takes just a few milliseconds. But your code is indeed faster.

Why the force-push, though? Did I do something wrong? Sorry, I'm not very familiar with collaborating on GitHub. I simply pushed a small commit adf9a29 to this branch after your commit https://github.com/jstac/sdfs_via_autodiff/commit/d020c6ce95c79c2e3031ffd4334b341f2d12064d.
> Why the force-push, though?
Since you updated the `main` branch by deleting the vectorized file, this PR had a conflict (meaning we couldn't merge it until that was resolved). So I rebased the branch on top of the latest `main` via a force-push, and mistakenly forgot to pick your commit https://github.com/jstac/sdfs_via_autodiff/commit/adf9a293a9817101331e8568b5b51d9b52a58f21. I added it manually again here: https://github.com/jstac/sdfs_via_autodiff/pull/7/commits/c714a5f68dffd8057d1bdadd16c709446c877ddc.
> Did I do something wrong? Sorry, I'm not very familiar with collaborating on GitHub.
You did it perfectly. I've added that commit back.
Hi @Smit-create, thanks for the explanation.
> Since you updated the `main` branch by deleting the vectorized file, this PR had a conflict (meaning we couldn't merge it until that was resolved). So I rebased the branch on top of the latest `main` via a force-push...
I'm not an expert on Git, but on a public branch like this, might it be a better idea to use `git merge` in such a scenario?
Anyways, I just tested your changes again and they are good. I'm merging.
> might it be a better idea to use `git merge` in such a scenario?
Well, it depends. Sometimes a merge commit is better because it makes the changes easier to review, while a force-push keeps the Git history clean by avoiding unnecessary merge commits, which can be a bit painful when reverting a PR.
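To illustrate the revert point: an ordinary commit can be undone with `git revert <sha>`, while a merge commit needs a parent chosen explicitly via `git revert -m 1 <merge-sha>`, which is the extra friction merge commits add.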