JaxGaussianProcesses / JaxKern

Kernel functions in JAX.
MIT License

dev: Use (or draw inspiration from) the Equinox.Module Class #37

adam-hartshorne closed this issue 1 year ago

adam-hartshorne commented 1 year ago

I notice that, at heart, the base classes are just PyTrees and as such can't store variables for later use; i.e., kernel parameters have to be passed in on every call to computations such as the Gram matrix or the cross-covariance.
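To make the current pattern concrete, here is a minimal sketch of a stateless kernel in plain JAX, assuming the parameters live in a separate dict. The functions `rbf`, `cross_covariance` and `gram` are hypothetical illustrations, not JaxKern's actual API.

```python
import jax
import jax.numpy as jnp


# Hypothetical stateless kernel (illustration only, not JaxKern's API):
# nothing is stored on the kernel, so `params` must be supplied on every call.
def rbf(params: dict, x: jax.Array, y: jax.Array) -> jax.Array:
    """Squared-exponential kernel between two single inputs x and y."""
    sq_dist = jnp.sum((x - y) ** 2)
    return params["variance"] * jnp.exp(-0.5 * sq_dist / params["lengthscale"] ** 2)


def cross_covariance(params: dict, X: jax.Array, Y: jax.Array) -> jax.Array:
    """Pairwise matrix K[i, j] = k(X[i], Y[j]) via nested vmaps."""
    return jax.vmap(lambda x: jax.vmap(lambda y: rbf(params, x, y))(Y))(X)


def gram(params: dict, X: jax.Array) -> jax.Array:
    """Gram matrix of X with itself."""
    return cross_covariance(params, X, X)


params = {"lengthscale": jnp.array(1.0), "variance": jnp.array(2.0)}
X = jnp.linspace(0.0, 1.0, 5).reshape(-1, 1)
K = gram(params, X)                        # params passed explicitly...
Kxy = cross_covariance(params, X, X[:2])   # ...and again on every call
```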

I wonder if you have considered using, or drawing inspiration from, the Equinox library (https://github.com/patrick-kidger/equinox/). Its core class, eqx.Module, registers each subclass as a PyTree, but does so in a way that lets you declare all of its fields at the class level, just as with dataclasses.

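import equinox
import jax

class MyModule(equinox.Module):
    # Fields are declared at class level, as in a dataclass.
    weight: jax.Array
    bias: jax.Array
    submodule: equinox.Module

    def __init__(self, weight, bias, submodule):
        # Each declared field must be assigned exactly once in __init__.
        self.weight = weight
        self.bias = bias
        self.submodule = submodule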

https://docs.kidger.site/equinox/api/module/

This is particularly useful because you can initialise all of a model's variables up front: both the parameters whose gradients you wish to track and any static variables. A single call to eqx.filter(model, ...) then returns a PyTree of all the trainable parameters stored anywhere in the model.

In my experience, this structure allows for much neater, more PyTorch-style code, while still interacting seamlessly with the rest of JAX.

thomaspinder commented 1 year ago

Hi @adam-hartshorne. Thanks for raising this.

Yes, we are looking to move towards exactly this framework. We hope to have something released by the end of the month, assuming there are no unforeseen stumbling blocks.