adam-hartshorne closed this issue 1 year ago.
Hi @adam-hartshorne. Thanks for raising this.
Yes, we are looking to move towards exactly this framework. We hope to have something released by the end of the month, assuming there are no unforeseen stumbling blocks...
I notice that, at their heart, the base classes are just PyTrees and as such can't store variables for later use; i.e., kernel parameters have to be passed in on every call to calculations such as the Gram matrix or the cross-covariance.
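For illustration, here is a minimal sketch of that stateless pattern written in plain JAX. The function names `rbf_gram` and `rbf_cross_covariance` and the `params` dictionary layout are hypothetical and are not taken from any library's API. The only point is that the kernel parameters travel alongside the inputs on every call.

```python
import jax
import jax.numpy as jnp


def rbf_cross_covariance(params, x, y):
    """Hypothetical RBF cross-covariance; params must be supplied on every call."""
    # Pairwise differences between rows of x (n, d) and y (m, d), scaled by the lengthscale.
    diff = (x[:, None, :] - y[None, :, :]) / params["lengthscale"]
    sq_dist = jnp.sum(diff**2, axis=-1)
    return params["variance"] * jnp.exp(-0.5 * sq_dist)


def rbf_gram(params, x):
    """Hypothetical Gram matrix: the cross-covariance of x with itself."""
    return rbf_cross_covariance(params, x, x)


params = {"lengthscale": jnp.array(1.0), "variance": jnp.array(2.0)}
x = jnp.linspace(0.0, 1.0, 5)[:, None]
K = rbf_gram(params, x)  # params are threaded through explicitly

# Gradients are taken with respect to the explicit params argument.
grads = jax.grad(lambda p: rbf_gram(p, x).sum())(params)
```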
I wonder whether you have considered using, or drawing inspiration from, the Equinox library (https://github.com/patrick-kidger/equinox/). Its core class, eqx.Module, registers subclasses as PyTrees, but in a way that lets you declare all of their fields at the class level, exactly as with dataclasses.
https://docs.kidger.site/equinox/api/module/
This is particularly useful because you can initialise variables on the module itself, both the parameters whose gradients you wish to track and static variables. A single call to eqx.filter(model, ...) then returns a PyTree of all the trainable parameters stored in the whole model.
In my experience, this structure allows for much neater, more PyTorch-style code, while still retaining the ability to interact seamlessly with the rest of JAX.