BYUCamachoLab / SiPANN

Artificial Neural Networks for use with Quantum Photonics
https://sipann.rtfd.io
MIT License
36 stars, 15 forks

Replace TensorFlow with JAX #23

Open joamatab opened 2 years ago

joamatab commented 2 years ago

TensorFlow is a heavy dependency (over 500 MB); it would be nice to replace it with JAX.

JAX can also be used to build neural networks, for example through Flax.

@jaspreetj @flaport @sequoiap @SkandanC @AustP

SkandanC commented 2 years ago

I have started working on this. The biggest issue, as I see it, is that SiPANN loads its neural network models from pre-saved TensorFlow MetaGraphs. I am not sure yet how to replace that with JAX, but I will figure it out eventually.

flaport commented 2 years ago

I think a better approach is to write the neural network in JAX/Flax and save out only the weights, then create a loader/saver that reads/writes the weights from/to a file.
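A minimal sketch of the "define the network in code, save only the weights" idea. It assumes a plain multilayer perceptron written directly with `jax.numpy` (not Flax, and not SiPANN's actual network architecture) and uses NumPy's `.npz` format as the weight file. The function names (`init_params`, `mlp`, `save_params`, `load_params`), the layer sizes and the `W0`/`b0` key scheme are all hypothetical.

```python
import io

import jax
import jax.numpy as jnp
import numpy as np


def init_params(key, layer_sizes):
    """Randomly initialise (W, b) pairs for a dense MLP (hypothetical architecture)."""
    params = []
    keys = jax.random.split(key, len(layer_sizes) - 1)
    for k, (n_in, n_out) in zip(keys, zip(layer_sizes[:-1], layer_sizes[1:])):
        w = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        b = jnp.zeros(n_out)
        params.append((w, b))
    return params


def mlp(params, x):
    """Forward pass: tanh on hidden layers, linear output layer."""
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def save_params(file, params):
    """Write weights as plain named arrays (W0, b0, W1, b1, ...) to an .npz file.

    `file` may be a path or a file-like object; for a string path without
    the .npz suffix, np.savez appends it automatically.
    """
    arrays = {}
    for i, (w, b) in enumerate(params):
        arrays[f"W{i}"] = np.asarray(w)
        arrays[f"b{i}"] = np.asarray(b)
    np.savez(file, **arrays)


def load_params(file):
    """Read weights written by save_params back into a list of (W, b) tuples."""
    with np.load(file) as data:
        n_layers = len(data.files) // 2
        return [
            (jnp.asarray(data[f"W{i}"]), jnp.asarray(data[f"b{i}"]))
            for i in range(n_layers)
        ]


# Usage: round-trip the weights through an in-memory file.
params = init_params(jax.random.PRNGKey(0), [3, 16, 16, 2])
buf = io.BytesIO()
save_params(buf, params)
buf.seek(0)
restored = load_params(buf)
x = jnp.ones((5, 3))
print(jnp.allclose(mlp(params, x), mlp(restored, x)))  # True
```

Because the file only holds named NumPy arrays, it can be inspected with `np.load` alone, without JAX or TensorFlow, which is the kind of transparency an opaque MetaGraph blob lacks.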

Saving the whole graph as a binary blob such as a saved MetaGraph is not transparent enough for an open source package in my opinion.

SkandanC commented 2 years ago

Added support for jax.numpy https://github.com/BYUCamachoLab/SiPANN/pull/31