brentyi / jaxlie

Rigid transforms + Lie groups in JAX
MIT License

Batch axes slicing #20

Open dongwoonhyun opened 5 months ago

dongwoonhyun commented 5 months ago

Hi Brent,

Thanks for the fantastic library. We're finding it very helpful in our ultrasound imaging research. The updated batch axes handling is great for our use case and solves the issues I was trying to hack around in a fork.

Another feature that would be nice is slicing into the batch axes of a MatrixLieGroup object.

For instance,

    import jaxlie as jl
    import numpy as onp
    a = jl.SE3(onp.random.random((4, 1, 10, 7)))
    a_slice = jl.SE3(a.wxyz_xyz[..., :5, :])  # current syntax
    a_slice = a[..., :5]  # desired syntax
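For context on why the current syntax needs the trailing `:`: an ellipsis expands over as many axes as it can. So forwarding `[..., :5]` straight to the raw `(4, 1, 10, 7)` parameter array would truncate the 7-dimensional parameter axis instead of the last batch axis. A quick NumPy-only check (random data, no jaxlie involved):

```python
import numpy as onp

params = onp.random.random((4, 1, 10, 7))  # batch shape (4, 1, 10), 7 SE(3) parameters

# Forwarding the key unchanged: the ellipsis absorbs the batch axes,
# so ":5" lands on the parameter axis and truncates each 7-vector.
wrong = params[..., :5]
print(wrong.shape)  # (4, 1, 10, 5)

# Appending a full slice keeps the parameter axis intact,
# so ":5" applies to the last batch axis instead.
right = params[..., :5, :]
print(right.shape)  # (4, 1, 5, 7)
```

Any `__getitem__` on the group therefore has to pin the parameter axis explicitly whenever the key contains an ellipsis.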

It's a minor thing, but it would improve our ergonomics a lot (e.g., inspecting the pose of an individual sensor in a large array). I think the following code snippet should add that functionality to MatrixLieGroup:

    def __getitem__(self, key):
        """Allow batch axes slicing using [] indexing operator."""
        # Requires `import jax_dataclasses as jdc` at module level.
        # If the key has an ellipsis (i.e. "..."), make sure the parameter dimension is
        # included explicitly. Only tuples are multi-axis keys in NumPy-style indexing,
        # and `is` avoids an elementwise `==` on any array entries in the key.
        if isinstance(key, tuple) and any(k is Ellipsis for k in key):
            key = (*key, slice(None))
        return self.__class__(
            **{f.name: getattr(self, f.name)[key] for f in jdc.fields(self)}
        )
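One self-contained way to sanity-check the pattern without touching jaxlie: the sketch below puts the same method on a hypothetical `ToySE3` frozen dataclass (a stand-in, not a jaxlie class) with a single `(*batch, 7)` field, using the standard-library `dataclasses.fields` in place of `jdc.fields`.

```python
import dataclasses

import numpy as onp


@dataclasses.dataclass(frozen=True)
class ToySE3:
    """Hypothetical stand-in for a jaxlie group: one field shaped (*batch, 7)."""

    wxyz_xyz: onp.ndarray

    def __getitem__(self, key):
        # Only a tuple key can pair an ellipsis with other indices; a bare `...`
        # or a plain int/slice never reaches the parameter axis.
        if isinstance(key, tuple) and any(k is Ellipsis for k in key):
            key = (*key, slice(None))
        # Apply the same index to every array field and rebuild the object.
        return type(self)(
            **{f.name: getattr(self, f.name)[key] for f in dataclasses.fields(self)}
        )


a = ToySE3(onp.random.random((4, 1, 10, 7)))
print(a[..., :5].wxyz_xyz.shape)  # (4, 1, 5, 7)
print(a[0].wxyz_xyz.shape)  # (1, 10, 7)
print(a[0, 0, 3].wxyz_xyz.shape)  # (7,)
```

Because the index is applied field by field, the same method should carry over to groups with differently named parameter fields (e.g. `SO3.wxyz`), as long as every field keeps the parameter dimension as its last axis.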

What do you think?

brentyi commented 5 months ago

Hi @dongwoonhyun, thanks for the suggestion! Glad the library is useful.

Sure, I'm happy to support this. Do you have time for a PR? At a glance, the implementation you wrote looks good to me; my only suggestions are: