Mentors: Ashish Agarwal, Allen Lavoie, Peng Wang, Dan Moldovan, Paige Bailey, Akshay Naresh Modi and Rohan Jain @ Google Brain
Special thanks also to Google Brain researchers Roman Novak and Sam Schoenholz for their insights and assistance.
Motivation: This project migrates and reconstructs Neural Tangents (a library for infinite-width neural networks), originally built on JAX (https://github.com/google/neural-tangents), for TensorFlow 2.x. The basic idea is that as the width of a neural network approaches infinity, the network behaves like a Gaussian process and its training dynamics become analytically tractable, which enables a better understanding of deep learning. We hope that, with the help of the rich TF ecosystem, this can power more state-of-the-art research in explainable AI and help build trustworthy machine learning systems.
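As a small numerical illustration of the infinite-width idea (a NumPy sketch, not the Neural Tangents API): for one hidden ReLU layer with i.i.d. N(0, 1/d) weights, a random linear readout f(x) = v · relu(Wx) / sqrt(width) with v ~ N(0, I) gives Cov(f(x), f(y) | W) equal to the average of relu(h_i(x)) * relu(h_i(y)) over hidden units. As the width grows, this average concentrates on a closed-form infinite-width kernel (the degree-1 arc-cosine kernel), which is the covariance of the limiting Gaussian process. The helper names, inputs, and weight scaling below are illustrative assumptions.

```python
import numpy as np


def nngp_relu_kernel(x, y):
    """Analytic infinite-width kernel E[relu(u) * relu(v)] for one ReLU layer.

    (u, v) are the pre-activations of one hidden unit with i.i.d. N(0, 1/d)
    weights, so they are jointly Gaussian with variances x.x/d, y.y/d and
    covariance x.y/d. The closed form is the degree-1 arc-cosine kernel.
    """
    d = x.shape[0]
    kxx, kyy, kxy = x @ x / d, y @ y / d, x @ y / d
    norm = np.sqrt(kxx * kyy)
    cos_t = np.clip(kxy / norm, -1.0, 1.0)  # guard against rounding outside [-1, 1]
    theta = np.arccos(cos_t)
    return float(norm * (np.sin(theta) + (np.pi - theta) * cos_t) / (2 * np.pi))


def empirical_relu_kernel(x, y, width, seed=0):
    """Finite-width estimate: mean of relu(h_i(x)) * relu(h_i(y)) over `width` random hidden units.

    This equals Cov(f(x), f(y) | W) for the random readout f described above.
    """
    rng = np.random.default_rng(seed)
    d = x.shape[0]
    w = rng.standard_normal((width, d)) / np.sqrt(d)  # i.i.d. N(0, 1/d) first-layer weights
    hx = np.maximum(w @ x, 0.0)  # ReLU post-activations for input x
    hy = np.maximum(w @ y, 0.0)  # ReLU post-activations for input y
    return float(np.mean(hx * hy))


x = np.array([1.0, 0.5, -0.3])
y = np.array([0.2, -1.0, 0.8])
analytic = nngp_relu_kernel(x, y)
for width in (10, 1_000, 100_000):
    # The gap to the analytic value shrinks roughly like 1/sqrt(width).
    print(width, empirical_relu_kernel(x, y, width), analytic)
```

Neural Tangents computes such kernels compositionally for whole architectures (and also the neural tangent kernel for training dynamics); this sketch only covers a single layer to show why width makes the randomness average out.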
Contributions: This is not a pure engineering project but an R&D project. Neural Tangents itself is a work in progress, and, to our knowledge, no one had attempted a migration from JAX to TensorFlow before this project, which added to the difficulty. Every research project risks failure; fortunately, after overcoming numerous difficulties, my mentors and I finished migrating the major APIs from JAX to TensorFlow (Pull Request 1 and Pull Request 2). However, this project is about far more than a straight migration of Neural Tangents: it is about enriching the TensorFlow NumPy extensions ecosystem (Pull Request 3 and Pull Request 4), testing the usability of the latest nightly version of TensorFlow, and exploring the compatibility and design differences between JAX and TensorFlow (see the list of issue logs, and an example of a TensorFlow improvement motivated by this migration: https://github.com/google/trax/pull/970).
Timeline: In May, my mentors and I had several discussions about the latest state-of-the-art papers. In June, I mainly focused on building helper APIs such as tf_jax_stax, tf_lax, general convolution, general dot, and reduce window.
From the end of June into early July, I began the official reconstruction and migration of NT from JAX to TF, including setting up Travis CI and integrating the tests. Then came the hardest part: debugging alongside the migration.
Acknowledgement: I was very lucky to collaborate with the Google Brain TensorFlow team and to be guided by my excellent mentors at TensorFlow. They always responded swiftly and showed great patience in helping me become a better problem solver. In particular, I want to thank Ashish Agarwal, Allen Lavoie, Peng Wang, Dan Moldovan, Paige Bailey, Akshay Naresh Modi and Rohan Jain for their guidance. In this project, I also had the chance to collaborate with and learn from Google Brain researchers Roman Novak and Sam Schoenholz, who helped me submit the changes to the Google Neural Tangents repo. Thank you all for an excellent summer!
Future Work: Although the major migration is done, the tests for some utility files still need to be run, and some potential compatibility issues remain to be addressed. I also plan to submit another Pull Request adding TF reduce_window to the TensorFlow NumPy extensions. In addition, I will improve the TF NumPy size API based on the code review feedback in the Pull Request.
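To make the planned reduce_window work concrete: reduce_window slides a window over an array and reduces each window (for example with max or sum), which is how stax-style pooling layers are built on top of lax. Below is a minimal NumPy reference sketch, assuming VALID padding only and NumPy 1.20 or later (for sliding_window_view). `reduce_window_valid` is a hypothetical helper whose argument names echo jax.lax.reduce_window; unlike that function it takes a NumPy reduction instead of an (init_value, computation) pair, and it is not the TF NumPy implementation from the Pull Request.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def reduce_window_valid(operand, reducer, window_dimensions, window_strides):
    """Hypothetical NumPy reference for reduce_window with VALID padding.

    `reducer` is a NumPy reduction that accepts an `axis` tuple (e.g. np.max, np.sum).
    """
    operand = np.asarray(operand)
    if len(window_dimensions) != operand.ndim or len(window_strides) != operand.ndim:
        raise ValueError("need one window size and one stride per operand dimension")
    # Shape: (out_0, ..., out_{n-1}, win_0, ..., win_{n-1}), stride 1 so far.
    windows = sliding_window_view(operand, tuple(window_dimensions))
    # Keep every s-th window along each leading (output) axis to apply the strides.
    windows = windows[tuple(slice(None, None, s) for s in window_strides)]
    # Reduce over the trailing window axes.
    window_axes = tuple(range(operand.ndim, 2 * operand.ndim))
    return reducer(windows, axis=window_axes)


x = np.arange(16).reshape(4, 4)
print(reduce_window_valid(x, np.max, (2, 2), (2, 2)))  # 2x2 max pooling, stride 2: values [[5, 7], [13, 15]]
print(reduce_window_valid(x, np.sum, (3, 3), (1, 1)))  # 3x3 sum pooling, stride 1: values [[45, 54], [81, 90]]
```

A full reduce_window would also need SAME or explicit padding (filling with the reduction's identity, e.g. -inf for max) and window dilation; those are omitted here to keep the semantics easy to check.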
We welcome any thoughts and ideas - zhibozhang@cs.toronto.edu
UPDATE: As of now, the major APIs of Neural Tangents are supported in TF. Feel free to run the example files function_space, infinite_fcn, and weight_space.
Novak, R., Xiao, L., Hron, J., Lee, J., Alemi, A. A., Sohl-Dickstein, J., & Schoenholz, S. S. (2019). Neural tangents: Fast and easy infinite neural networks in python. arXiv preprint arXiv:1912.02803.