Setting things to x64 (64-bit) mode makes this even more of an issue when using torch2jax_with_vjp:
1.1318159821788687 1.4969821824080587 1.4862239786651517 1.4780943113719986 1.472037333444222 1.4678562092608634 ....
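For context, here is a minimal sketch of the kind of setup being described: enabling 64-bit mode in JAX and wrapping a torch function with torch2jax_with_vjp. The call pattern is assumed from the project README, and the loss function and shapes below are placeholders, not the attached example.

import jax
import jax.numpy as jnp
import torch
from torch2jax import torch2jax_with_vjp

jax.config.update("jax_enable_x64", True)  # run JAX in 64-bit (x64) mode

def torch_loss(x, y):
    # placeholder torch computation standing in for the real loss
    return torch.mean((x - y) ** 2)

# wrap the torch function so jax.grad can differentiate through it;
# example tensors give torch2jax the expected shapes and (float64) dtypes
example = torch.zeros(10, 3, dtype=torch.float64)
jax_loss = torch2jax_with_vjp(torch_loss, example, example)

x = jnp.ones((10, 3))
y = jnp.zeros((10, 3))
value, grad = jax.value_and_grad(jax_loss)(x, y)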
Thanks for submitting this!
Thanks to your code example I was able to catch and fix two errors:
torch2jax needed to call torch.cuda.synchronize before the torch call, or the data read was sometimes not synchronized.

Make sure to delete ~/.cache/torch2jax or call
from torch2jax import compile_and_import_module
compile_and_import_module(force_recompile=True)
at least once after pulling the new changes.
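Equivalently, assuming the cache lives at its default ~/.cache/torch2jax location, the stale build can be removed from Python before re-importing:

import shutil
from pathlib import Path

# remove the cached extension build so it is recompiled on next use
shutil.rmtree(Path.home() / ".cache" / "torch2jax", ignore_errors=True)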
Looks good from my end!
Really appreciate how fast you are at correcting these bugs.
Great, no worries. Your example code makes for much better testing than I managed to write myself; I really appreciate your help.
Find attached a minimal example and the resulting issue. It definitely looks like a synchronisation issue.

Furthermore, in my actual usage, which is in x64 mode, I also found that as the loss / gradients get small the optimisation seems to become stuck. It appears there is some sort of issue with the numerical precision of the gradients (I think that is what is going on in my x64 results).
Running pure JAX:
1.1784207 1.0076947 0.8537498 0.7236823 0.6130994 0.5178365 0.4353472 0.36305848 0.2981228 0.23942232 0.18905437 0.14701995 0.11380868 0.088957496 0.07115974 0.05860029 0.050203666 0.044576604 0.040941507 0.03873283 0.037716232 0.03754989 0.03794671 0.038578086 0.039351515
Running torch2jax:
1.1784207 1.0076947 0.8537498 0.72368234 0.61309934 0.51783663 0.43534735 0.36305854 0.29812282 0.23942222 0.18905428 0.14701991 0.113808714 0.088957384 1.4932078 <--------------- 0.058600325 0.050203983 0.04457575 1.5029242 <--------------- 0.03873186 0.037716057 0.037549183 0.037946653 0.03857988 1.513658 <---------------
Running the same torch_chamfer_distance in a pure PyTorch training loop:
1.176 1.020 0.879 0.749 0.632 0.527 0.434 0.353 0.284 0.227 0.181 0.144 0.115 0.092 0.074 0.062 0.052 0.045 0.041 0.037 0.035 0.034 0.033 0.032 0.032
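The attached example is not reproduced here; the following is only a rough sketch of how such a comparison loop might look, assuming torch_chamfer_distance is a plain PyTorch function and using the torch2jax_with_vjp call pattern from the README. The chamfer implementation below is a stand-in, not the user's code.

import jax
import jax.numpy as jnp
import torch
from torch2jax import torch2jax_with_vjp

def torch_chamfer_distance(a, b):
    # stand-in implementation: mean nearest-neighbour distance in both directions
    d = torch.cdist(a, b)
    return d.min(dim=1).values.mean() + d.min(dim=0).values.mean()

example = torch.zeros(128, 3)
loss_fn = torch2jax_with_vjp(torch_chamfer_distance, example, example)

points = jax.random.normal(jax.random.PRNGKey(0), (128, 3))
target = jax.random.normal(jax.random.PRNGKey(1), (128, 3))

lr = 1e-1
for _ in range(25):
    value, grad = jax.value_and_grad(loss_fn)(points, target)
    points = points - lr * grad  # plain gradient descent, mirroring the loss traces above
    print(float(value))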