Closed mr-raccoon-97 closed 3 weeks ago
You can even add type checking to tensors with Python generics embedded in the Tensor class, something like Tensor[Float], Tensor[Int].
If you say yes I can open a pull request.
What does this Tensor class actually offer over jax arrays? It seems like a trivial wrapper, and seems philosophically divergent from equinox's idea of operating at the jax level (rather than being a jax wrapper, like flax, or other jax packages).
The issue with jax Arrays is that they are meant to be used with functional programming, but equinox is object oriented. Array is not even implemented; I think it's an abstract base class that they made for type hints. Also, it's difficult to add features to something that is not yours.
But anyway, I think you are right and this is not the place for this. I'm gonna try to create a Tensor data structure and add type checking using jaxtyping with generics. I would like this:
def matrix_multiply(x: Float[Array, "dim1 dim2"],
                    y: Float[Array, "dim2 dim3"]
                    ) -> Float[Array, "dim1 dim3"]:
    ...
To become something like this:
def matrix_multiply(x: Tensor[Float], y: Tensor[Float]) -> Tensor[Float]:
    ...
But yes, it should be a separate library.
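For anyone curious what the Tensor[Float] spelling could look like, here is a rough sketch using only typing.Generic. Everything in it is hypothetical: the Tensor wrapper and the Float and Int tags are made-up names, not jaxtyping or equinox classes, and plain nested lists stand in for jax arrays so the snippet has no dependencies.

```python
from typing import Generic, TypeVar, get_args, get_type_hints


class Float:
    """Hypothetical dtype tag (not jaxtyping.Float)."""


class Int:
    """Hypothetical dtype tag (not jaxtyping.Int)."""


DType = TypeVar("DType")


class Tensor(Generic[DType]):
    """Hypothetical wrapper; in practice `data` would hold a jax array."""

    def __init__(self, data):
        self.data = data


def matrix_multiply(x: Tensor[Float], y: Tensor[Float]) -> Tensor[Float]:
    # Pure-Python matmul on nested lists, standing in for `x.data @ y.data`.
    rows = [
        [sum(a * b for a, b in zip(row, col)) for col in zip(*y.data)]
        for row in x.data
    ]
    return Tensor(rows)


# The dtype parameter can be read back from the annotations...
hints = get_type_hints(matrix_multiply)
dtype_of_x = get_args(hints["x"])[0]

# ...but nothing enforces it at runtime: a Tensor[Int] is accepted here.
unchecked = matrix_multiply(Tensor[Int]([[1, 2]]), Tensor[Int]([[3], [4]]))
```

Two things stand out in this sketch: the generic parameter is only a hint unless a static or runtime type checker is involved, and the shape names from the jaxtyping version ("dim1 dim2") have nowhere to go in Tensor[Float].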
Why don't you add a Tensor class?
People using equinox usually come from torch, otherwise they would choose another library, so they expect some object oriented programming.
It's easy: just add something like this to the '__init__.py' file (but obviously better):
So you will have constructors without from jax import numpy but with from equinox import Tensor, making the imports cleaner. Constructors will all be under the same factory. Type hints stay backward compatible with the Array abstract class.
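A minimal sketch of the kind of factory described above, assuming it is nothing more than a namespace over existing jax.numpy constructors. The Tensor class and its zeros, ones and of methods are hypothetical names for illustration; equinox does not ship anything like this.

```python
import jax
import jax.numpy as jnp


class Tensor:
    """Hypothetical constructor namespace; not part of equinox or jax."""

    @staticmethod
    def zeros(shape, dtype=jnp.float32) -> jax.Array:
        # Thin pass-through to the existing jax.numpy constructor.
        return jnp.zeros(shape, dtype=dtype)

    @staticmethod
    def ones(shape, dtype=jnp.float32) -> jax.Array:
        return jnp.ones(shape, dtype=dtype)

    @staticmethod
    def of(values, dtype=None) -> jax.Array:
        # Build an array from Python data such as nested lists.
        return jnp.asarray(values, dtype=dtype)


x = Tensor.zeros((2, 3))
y = Tensor.of([1, 2, 3])
```

Because every constructor returns a plain jax.Array, existing Array type hints keep working, but it also means the class adds no behavior of its own beyond grouping the constructors.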
And let's be honest, no one likes the Array name. It makes no sense and it's just ugly. Tensor is a very acceptable name for a differentiable multidimensional array data structure, while Array just means a chunk of numbers. When I see Tensor in my code, I know I'm talking about machine learning, deep learning stuff, but when I see Array I don't know if it is some kind of tutorial about matrices from the uni's first year data structures course.