patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Add Tensor class #808

Closed: mr-raccoon-97 closed this issue 3 weeks ago

mr-raccoon-97 commented 3 weeks ago

Why don't you add a Tensor class?

People using equinox usually come from PyTorch (otherwise they would choose another library), so they expect some object-oriented programming.

It's easy: just add something like this (but obviously better) to the __init__.py file:

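from jax import Array
from jax import numpy


class Tensor(Array):
    # staticmethod keeps each constructor a plain function on the class
    ones = staticmethod(numpy.ones)
    zeros = staticmethod(numpy.zeros)
    zeros_like = staticmethod(numpy.zeros_like)
    ones_like = staticmethod(numpy.ones_like)
    ...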

That way you would get constructors via from equinox import Tensor instead of from jax import numpy, which makes the imports cleaner, and all the constructors would live under the same factory:

x = Tensor.ones(3)
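A minimal sketch of the factory idea, assuming a plain namespace class rather than a subclass of jax.Array (Tensor is the hypothetical name from this proposal, not an equinox API). It shows that the constructors simply forward to jax.numpy and return ordinary JAX arrays.

```python
import jax
import jax.numpy as jnp


class Tensor:
    """Hypothetical namespace grouping jax.numpy constructors (not part of equinox)."""

    # staticmethod keeps these as plain functions when accessed on the class.
    ones = staticmethod(jnp.ones)
    zeros = staticmethod(jnp.zeros)
    ones_like = staticmethod(jnp.ones_like)
    zeros_like = staticmethod(jnp.zeros_like)


x = Tensor.ones(3)
y = Tensor.zeros_like(x)

# The results are ordinary JAX arrays, not instances of Tensor.
print(isinstance(x, jax.Array), isinstance(x, Tensor))  # prints: True False
```

Subclassing jax.Array, as in the proposal, would not change this: the forwarded constructors still return JAX's own array type.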

Type hints stay backward compatible, since Tensor subclasses the Array abstract class:


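import jax
from equinox import Module
from equinox.nn import Linear, Sequential

# Tensor is the class proposed above (not part of equinox).


class MLP(Module):
    layers: Sequential

    def __init__(self, input_size: int, hidden_size: int, output_size: int, *, key):
        # split the PRNG key so the two layers get independent initial weights
        key1, key2 = jax.random.split(key)
        self.layers = Sequential([
            Linear(input_size, hidden_size, key=key1),
            Linear(hidden_size, output_size, key=key2),
        ])

    def __call__(self, x: Tensor) -> Tensor:
        return self.layers(x)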

And let's be honest, no one likes the Array name. It makes no sense and it's just ugly. Tensor is a very acceptable name for a differentiable multidimensional array data structure, while Array just means a chunk of numbers. When I see Tensor in my code, I know I'm dealing with machine learning and deep learning stuff, but when I see Array I don't know whether it's some kind of tutorial about matrices from a first-year university data structures course.

mr-raccoon-97 commented 3 weeks ago

You could even add type checking to tensors with Python generics embedded in the class, something like:

Tensor[Float], Tensor[Int]

If you say yes I can open a pull request.

lockwo commented 3 weeks ago

What does this Tensor class actually offer over jax arrays? It seems like a trivial wrapper, and seems philosophically divergent from equinox's idea of operating at the jax level (rather than being a jax wrapper, like flax, or other jax packages).

mr-raccoon-97 commented 3 weeks ago


The issue with the JAX Array is that it is meant to be used with functional programming, while equinox is object-oriented. It's not even a concrete implementation: I think it's an abstract base class they made for type hints. Also, it's difficult to add features to something that is not yours.
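A quick check of the "abstract base class for type hints" point: jax.Array is the public type used in annotations and isinstance checks, while concrete arrays belong to an internal implementation class. The exact internal class name depends on the JAX version, so it is printed rather than asserted.

```python
import jax
import jax.numpy as jnp

x = jnp.ones(3)

# jax.Array is what user code annotates with and checks against...
print(isinstance(x, jax.Array))  # True
# ...but the concrete object is an instance of an internal subclass.
print(type(x) is jax.Array)      # False
print(type(x).__name__)          # internal class name, version-dependent
```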

But anyway, I think you are right and this is not the place for this. I'm gonna try to create a Tensor data structure and add type checking using jaxtyping with generics. I would like this:

from jax import Array
from jaxtyping import Float

def matrix_multiply(x: Float[Array, "dim1 dim2"],
                    y: Float[Array, "dim2 dim3"]) -> Float[Array, "dim1 dim3"]:
    return x @ y

To become something like this:

def matrix_multiply(x: Tensor[Float], y: Tensor[Float]) -> Tensor[Float]:
    return x @ y

But yes, this should be a separate library.