joamatab opened this issue 2 years ago
I have started working on this. The biggest issue, as I see it, is that SiPANN uses pre-saved TensorFlow MetaGraphs to load the neural-network models. I am not sure yet how to replace that with JAX, but I will figure it out eventually.
I think a better approach is to write the neural networks in JAX/Flax and save only the weights, then write a loader/saver that reads/writes those weights from/to a file.
Saving the whole graph as a binary blob, such as a saved MetaGraph, is not transparent enough for an open-source package in my opinion.
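A minimal sketch of the "network in JAX, weights in a file" idea, written in plain JAX rather than Flax to keep it dependency-light. The MLP layout (`init_mlp`, `mlp_apply`), the tanh activation, the layer sizes and the `.npz` file format are hypothetical choices for illustration only, not SiPANN's actual models. Each weight matrix and bias is stored as a named array, so the file can be inspected with `numpy.load` without any ML framework installed.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np


def init_mlp(key, sizes):
    """Hypothetical MLP: returns a list of (W, b) pairs, one per layer."""
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, (n_in, n_out) in zip(keys, zip(sizes[:-1], sizes[1:])):
        w = jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in)  # scaled init
        b = jnp.zeros(n_out)
        params.append((w, b))
    return params


def mlp_apply(params, x):
    """Forward pass: tanh on hidden layers, linear output layer."""
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def save_weights(path, params):
    """Write every layer as plainly named arrays (layer0_W, layer0_b, ...)."""
    arrays = {}
    for i, (w, b) in enumerate(params):
        arrays[f"layer{i}_W"] = np.asarray(w)
        arrays[f"layer{i}_b"] = np.asarray(b)
    np.savez(path, **arrays)


def load_weights(path):
    """Read the arrays back into the same list-of-(W, b) structure."""
    with np.load(path) as data:
        n_layers = len(data.files) // 2
        return [
            (jnp.asarray(data[f"layer{i}_W"]), jnp.asarray(data[f"layer{i}_b"]))
            for i in range(n_layers)
        ]


if __name__ == "__main__":
    params = init_mlp(jax.random.PRNGKey(0), [3, 16, 2])
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "weights.npz")
        save_weights(path, params)
        restored = load_weights(path)
        print(mlp_apply(restored, jnp.ones((1, 3))))
```

If the models are written in Flax instead, `flax.serialization.to_bytes` and `flax.serialization.from_bytes` provide an equivalent save/restore step for the parameter tree; the `.npz` route above trades that convenience for a file format any NumPy user can open.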
Added support for jax.numpy in https://github.com/BYUCamachoLab/SiPANN/pull/31
TensorFlow is a heavy dependency (>500 MB); it would be nice to replace it with JAX.
JAX can also be used for neural networks, through libraries built on top of it such as Flax.
@jaspreetj @flaport @sequoiap @SkandanC @AustP