Closed trivoldus28 closed 2 years ago
@pattonw I tried to resolve most of your comments... can you take a second look? Thanks!
Documentation is nice and easy to follow! Thanks
Would still like to see tests for the Jax nodes to provide examples on how to use them.
Looks good to me now. I think we're good to merge
Actually 1 last thing, can you add the new nodes to the docs? That can be done in the api.rst file
Merging to v1.3-dev with @pattonw's approval.
Notes:
Predict
and then used as a parent class for JAXPredict. This should make using BufferPredict easier for other implementations.
Train
was written based on the PyTorch implementation. Since JAX is rather low level and allows a lot of freedom in how a model is implemented, we define a simple GenericJaxModel interface for models to follow in order to train or predict. A model implementing this interface needs to contain not only the forward model but also the loss and update functions. Some examples can be found in https://github.com/funkelab/funlib.learn.jax
Pre-Merge Checklist:
vX.Y-dev
)patch-X.Y.Z
)
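For readers who want a picture of the interface idea in the notes above (one object bundling the forward model, the loss, and the update function), here is a minimal sketch. It is an assumption-laden illustration, not the actual GenericJaxModel signature: the class names `ModelInterfaceSketch` and `LinearModel`, the method names, and the `train` helper are all hypothetical. It is written in plain Python with a hand-derived gradient so that it runs without JAX installed; in a real JAX model the gradient would typically come from `jax.grad` and the parameters would be a pytree of arrays.

```python
from abc import ABC, abstractmethod


class ModelInterfaceSketch(ABC):
    """Hypothetical stand-in for the interface idea; NOT the real GenericJaxModel."""

    @abstractmethod
    def initialize(self):
        """Return the initial parameters."""

    @abstractmethod
    def forward(self, params, inputs):
        """Return predictions for `inputs` (used for both training and prediction)."""

    @abstractmethod
    def loss(self, params, inputs, targets):
        """Return a scalar loss."""

    @abstractmethod
    def update(self, params, inputs, targets):
        """Run one optimization step and return (new_params, loss_before_step)."""


class LinearModel(ModelInterfaceSketch):
    """Toy model y = w * x + b trained with mean squared error."""

    def __init__(self, learning_rate=0.1):
        self.learning_rate = learning_rate

    def initialize(self):
        return {"w": 0.0, "b": 0.0}

    def forward(self, params, inputs):
        return [params["w"] * x + params["b"] for x in inputs]

    def loss(self, params, inputs, targets):
        preds = self.forward(params, inputs)
        return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(inputs)

    def update(self, params, inputs, targets):
        preds = self.forward(params, inputs)
        n = len(inputs)
        # Hand-derived MSE gradients; a JAX model would use jax.grad instead.
        grad_w = sum(2 * (p - t) * x for p, t, x in zip(preds, targets, inputs)) / n
        grad_b = sum(2 * (p - t) for p, t in zip(preds, targets)) / n
        new_params = {
            "w": params["w"] - self.learning_rate * grad_w,
            "b": params["b"] - self.learning_rate * grad_b,
        }
        return new_params, self.loss(params, inputs, targets)


def train(model, inputs, targets, steps):
    """A training loop only needs the interface, not the model internals."""
    params = model.initialize()
    for _ in range(steps):
        params, _ = model.update(params, inputs, targets)
    return params


if __name__ == "__main__":
    # Fit y = 2x + 1 from four points.
    fitted = train(LinearModel(), [0.0, 1.0, 2.0, 3.0], [1.0, 3.0, 5.0, 7.0], steps=500)
    print(fitted)
```

The point of the design is that a train or predict node can stay generic: it calls `update` during training and `forward` during prediction, and never needs to know how the model or its optimizer is built.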