samuela / torch2jax

Run PyTorch in JAX. 🤝

Einops support #5

Open jrichterpowell opened 2 weeks ago

jrichterpowell commented 2 weeks ago

Hey guys,

Love this tool!

It's extremely useful for me as someone much more comfortable with JAX than torch. I've been using it for a project and have extended some of the API surface coverage along the way (I have another issue I might open re: type promotion), but the application I'm targeting also uses einops extensively. I hacked together preliminary support for that by subclassing the einops array backend as below and sticking the class at the bottom of your library's __init__.py. This seems more or less like the correct approach (I still need to fix the layers implementation, of course), but I'd be curious to hear feedback if you guys have a better idea for accomplishing this. The class is almost identical to the PyTorch backend from einops, just changed in a few places to use the Torchish class where needed.

If it's acceptable, I'll try to open a PR soon :)

from einops._backends import AbstractBackend
class TorchishBackend(AbstractBackend):
    framework_name = "torch2jax"

    def __init__(self):
        import torch
        import torch2jax

        self.torch = torch
        self.t2j = torch2jax

    def is_appropriate_type(self, tensor):
        # dispatch to this backend whenever einops encounters a Torchish tensor
        return isinstance(tensor, self.t2j.Torchish)

    def from_numpy(self, x):
        return self.t2j.Torchish(x)

    def to_numpy(self, x):
        return x.value.cpu().numpy()

    def arange(self, start, stop):
        return self.torch.arange(start, stop)

    def reduce(self, x, operation, reduced_axes):
        if operation == "min":
            return x.amin(dim=reduced_axes)
        elif operation == "max":
            return x.amax(dim=reduced_axes)
        elif operation == "sum":
            return x.sum(dim=reduced_axes)
        elif operation == "mean":
            return x.mean(dim=reduced_axes)
        elif operation in ("any", "all", "prod"):
            # pytorch supports reducing only one operation at a time
            for i in list(sorted(reduced_axes))[::-1]:
                x = getattr(x, operation)(dim=i)
            return x
        else:
            raise NotImplementedError("Unknown reduction ", operation)

    def transpose(self, x, axes):
        return x.permute(*axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.torch.stack(tensors)

    def add_axes(self, x, n_axes, pos2len):
        repeats = [-1] * n_axes
        for axis_position, axis_length in pos2len.items():
            x = self.add_axis(x, axis_position)
            repeats[axis_position] = axis_length
        return x.expand(repeats)

    def tile(self, x, repeats):
        return x.repeat(repeats)

    def concat(self, tensors, axis: int):
        return self.torch.cat(tensors, dim=axis)

    def add_axis(self, x, new_position):
        return self.torch.unsqueeze(x, new_position)

    def is_float_type(self, x):
        return x.dtype in [self.torch.float16, self.torch.float32, self.torch.float64, self.torch.bfloat16]

    # def layers(self):
    #     from .layers import torch

    #     return torch

    def einsum(self, pattern, *x):
        return self.torch.einsum(pattern, *x)
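
For reference, here's roughly how I've been exercising it. This is just a minimal sketch assuming the class above is registered at the bottom of torch2jax's __init__.py and that the function goes through the usual t2j entry point; torch_fn and the shapes are made up for illustration.

import jax.numpy as jnp
import torch
from einops import rearrange
from torch2jax import t2j


def torch_fn(x):
    # Inside t2j tracing, x is a Torchish, so einops dispatches the
    # rearrange call to TorchishBackend via is_appropriate_type.
    return rearrange(torch.sin(x), "b c h w -> b (c h w)")


out = t2j(torch_fn)(jnp.ones((2, 3, 4, 4)))  # shape (2, 48)
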
samuela commented 2 weeks ago

Hi @jrichterpowell, thanks for putting this together! I would absolutely be open to a PR adding einops support. I'm not familiar with the internals of einops or with extending it to other backends, but at a cursory glance this looks reasonable to me.

One question: how does einops know about the existence of TorchishBackend? Does the backend need to be registered somehow?