facebookresearch / pytorch3d

PyTorch3D is FAIR's library of reusable components for deep learning with 3D data
https://pytorch3d.org/

Allow indexing for classes inheriting Transform3d #1801

Closed ListIndexOutOfRange closed 5 months ago

ListIndexOutOfRange commented 5 months ago

Currently, none of the 3D transforms that inherit from the Transform3d class can be indexed to access a sub-transform. For instance:

from pytorch3d import transforms

N = 10
r = transforms.random_rotations(N)
T = transforms.Transform3d().rotate(R=r)
R = transforms.Rotate(r)

x = T[0]  # ok
x = R[0]  # TypeError: __init__() got an unexpected keyword argument 'matrix'

This is because all these classes (namely Rotate, Translate, Scale and RotateAxisAngle) inherit the __getitem__() method from Transform3d, which has the following code on line 201:

return self.__class__(matrix=self.get_matrix()[index])
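
The failure can be reproduced without PyTorch3D. The sketch below uses hypothetical toy stand-ins for Transform3d and Rotate (not the real classes): each batch item is an opaque label instead of a 4x4 tensor, because only the construction path matters here. Indexing re-creates the object through self.__class__, which works for the base class but passes an unexpected matrix keyword to the subclass constructor.

```python
# Toy stand-ins for pytorch3d's classes (hypothetical, for illustration only):
# each batch item is an opaque label rather than a 4x4 tensor.

class Transform3d:
    def __init__(self, matrix=None):
        # A batch of per-item "matrices"; default to a single identity.
        self._matrix = ["I"] if matrix is None else list(matrix)

    def get_matrix(self):
        return self._matrix

    def __getitem__(self, index):
        m = self.get_matrix()
        # Keep the batch dimension when indexing with a single int.
        selected = [m[index]] if isinstance(index, int) else m[index]
        # The problematic pattern: re-create via the *runtime* class.
        return self.__class__(matrix=selected)


class Rotate(Transform3d):
    # Like the real subclasses, the constructor takes its own arguments, not `matrix`.
    def __init__(self, R):
        super().__init__(matrix=list(R))


T = Transform3d(matrix=["m0", "m1"])
R = Rotate(["r0", "r1"])

print(type(T[0]).__name__)  # Transform3d
try:
    R[0]
except TypeError as err:
    print("TypeError:", err)  # Rotate.__init__ rejects the `matrix` keyword
```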

None of the four subclasses of Transform3d accepts a matrix argument in its constructor, so indexing them raises this error. I propose to modify the __getitem__() method of the Transform3d class to fix this. The least invasive approach I can think of is to create an empty instance of the current class and then set its _matrix attribute manually. Thus, instead of

return self.__class__(matrix=self.get_matrix()[index])

I propose to do:

instance = self.__class__.__new__(self.__class__)
instance._matrix = self.get_matrix()[index]
return instance

As far as I can tell, this change has no visible effect for users, apart from making all 3D transforms indexable.
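
One plain-Python detail is worth keeping in mind with this proposal: calling __new__ directly allocates the object without running __init__, so the returned instance has the right class and _matrix, but none of the other attributes that __init__ normally sets. The sketch below shows this with hypothetical toy classes (labels instead of tensors; the device and _transforms attributes are assumptions standing in for whatever the real __init__ sets).

```python
# Toy stand-ins (hypothetical): batch items are opaque labels, not tensors.

class Transform3d:
    def __init__(self, matrix=None, device="cpu"):
        self._matrix = ["I"] if matrix is None else list(matrix)
        self._transforms = []  # attributes that only __init__ sets
        self.device = device

    def _select(self, index):
        # Keep the batch dimension when indexing with a single int.
        return [self._matrix[index]] if isinstance(index, int) else self._matrix[index]

    def __getitem__(self, index):
        # Proposed fix: allocate an instance of the runtime class
        # without calling its __init__, then set _matrix by hand.
        instance = self.__class__.__new__(self.__class__)
        instance._matrix = self._select(index)
        return instance


class Rotate(Transform3d):
    def __init__(self, R, device="cpu"):
        super().__init__(matrix=list(R), device=device)


R = Rotate(["r0", "r1", "r2"])
x = R[0]
print(type(x).__name__, x._matrix)  # Rotate ['r0']
print(hasattr(x, "device"))         # False: __init__ never ran on x
```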

bottler commented 5 months ago

Thanks for the report. Could the offending line be more simply changed from:

self.__class__(matrix=self.get_matrix()[index])

to

Transform3d(matrix=self.get_matrix()[index])

?

The use of __class__ seems intended to make subclasses work nicely, but in practice it fails for all of them.

If indexing a Rotate, Translate, Scale or RotateAxisAngle instance should return another such instance, it would be best to give those classes their own __getitem__s.
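
A minimal sketch of the first suggestion, using hypothetical toy classes (labels instead of tensors): the base-class __getitem__ always builds a plain Transform3d, whose constructor does accept matrix, so indexing any subclass succeeds but returns the base type.

```python
# Toy stand-ins (hypothetical): batch items are opaque labels, not tensors.

class Transform3d:
    def __init__(self, matrix=None):
        self._matrix = ["I"] if matrix is None else list(matrix)

    def get_matrix(self):
        return self._matrix

    def __getitem__(self, index):
        m = self.get_matrix()
        selected = [m[index]] if isinstance(index, int) else m[index]
        # Always build the base class, whose constructor accepts `matrix`.
        return Transform3d(matrix=selected)


class Rotate(Transform3d):
    def __init__(self, R):
        super().__init__(matrix=list(R))


R = Rotate(["r0", "r1", "r2"])
print(type(R[0]).__name__, R[0].get_matrix())    # Transform3d ['r0']
print(type(R[1:]).__name__, R[1:].get_matrix())  # Transform3d ['r1', 'r2']
```

The trade-off is that the subclass type is lost after indexing, which is what the discussion below turns on.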

ListIndexOutOfRange commented 5 months ago

OK, I have updated the code.

Indeed, I think changing the class by indexing (i.e. R[0] being an instance of Transform3d when R is an instance of Rotate) should be avoided. On the other hand, I feel it's a bit much to write a __getitem__() method for every class inheriting from Transform3d. Since Transform3d has only four attributes besides _matrix, I propose to add the following lines to Transform3d's __getitem__() method:

for attr in ('_transforms', '_lu', 'device', 'dtype'):
    setattr(instance, attr, getattr(self, attr))
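
Combined with the __new__ approach, the attribute-copying loop can be sketched as below. The classes are hypothetical toys (labels instead of tensors, strings for device and dtype). Note that setattr copies references, so the indexed object shares the same _transforms list object as its parent; whether that matters depends on how the real class uses these attributes, which this toy does not model.

```python
# Toy stand-ins (hypothetical): batch items are opaque labels, not tensors.

class Transform3d:
    def __init__(self, matrix=None, device="cpu"):
        self._matrix = ["I"] if matrix is None else list(matrix)
        # The four other attributes this toy's __init__ sets.
        self._transforms = []
        self._lu = None
        self.device = device
        self.dtype = "float32"

    def __getitem__(self, index):
        # Allocate the runtime class without running its __init__ ...
        instance = self.__class__.__new__(self.__class__)
        m = self._matrix
        instance._matrix = [m[index]] if isinstance(index, int) else m[index]
        # ... then copy over the attributes __init__ would have set.
        for attr in ("_transforms", "_lu", "device", "dtype"):
            setattr(instance, attr, getattr(self, attr))
        return instance


class Rotate(Transform3d):
    def __init__(self, R, device="cpu"):
        super().__init__(matrix=list(R), device=device)


R = Rotate(["r0", "r1", "r2"], device="cuda:0")
x = R[1:]
print(type(x).__name__, x._matrix, x.device)  # Rotate ['r1', 'r2'] cuda:0
# setattr copies references: x shares the parent's _transforms list object.
print(x._transforms is R._transforms)         # True
```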

bottler commented 5 months ago

I think I disagree with both these points. Rotate, Scale, etc. are effectively just alternate ways to initialize a Transform3d, so there's nothing wrong if R[0] returns a Transform3d. And writing a __getitem__ for each class would not be much code.

ListIndexOutOfRange commented 5 months ago

I see your point but

ListIndexOutOfRange commented 5 months ago

Well, in the end, I implemented the __getitem__() method for each subclass. Sorry, I didn't understand at first why the other approach wasn't working properly. Indexing now returns the same type as the indexed object, and the behavior should be correct. Let me know what you think!
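
A minimal sketch of the per-subclass idea, using hypothetical toy classes; it is not the code that was merged. Each subclass overrides __getitem__ and rebuilds itself through its own constructor argument. In the real library the subclasses build a 4x4 matrix from their arguments, so recovering those arguments would mean slicing the relevant block of the matrix; the toy stores them as-is to keep that step trivial.

```python
# Toy stand-ins (hypothetical): batch items are opaque labels, not tensors.

class Transform3d:
    def __init__(self, matrix=None):
        self._matrix = ["I"] if matrix is None else list(matrix)

    def get_matrix(self):
        return self._matrix

    def _select(self, index):
        m = self.get_matrix()
        # Keep the batch dimension when indexing with a single int.
        return [m[index]] if isinstance(index, int) else m[index]

    def __getitem__(self, index):
        return Transform3d(matrix=self._select(index))


class Rotate(Transform3d):
    def __init__(self, R):
        # The toy stores R directly instead of embedding it in a 4x4 matrix.
        super().__init__(matrix=list(R))

    def __getitem__(self, index):
        # Rebuild through this class's own constructor argument.
        return Rotate(self._select(index))


class Translate(Transform3d):
    def __init__(self, xyz):
        super().__init__(matrix=list(xyz))

    def __getitem__(self, index):
        return Translate(self._select(index))


R = Rotate(["r0", "r1", "r2"])
t = Translate(["t0", "t1"])
print(type(R[0]).__name__, R[0].get_matrix())    # Rotate ['r0']
print(type(t[-1]).__name__, t[-1].get_matrix())  # Translate ['t1']
print(type(Transform3d()[0]).__name__)           # Transform3d
```

Naming the class explicitly (Rotate(...) rather than self.__class__(...)) avoids reintroducing the original problem if a further subclass changes the constructor signature, at the cost of returning the parent type for such subclasses.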

bottler commented 5 months ago

Implementation is good. Would you be able to add a test case for each? Then I'd be happy to merge.

ListIndexOutOfRange commented 5 months ago

I'm not exactly sure how to do it, to be honest, but here is a first attempt. As far as I understand, it is not necessary to test every type of index (int, list, slice, ...), but let me know if that's not the case.
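
A test for this can be sketched with the standard unittest module. The classes below are hypothetical toys (labels instead of tensors), not the real PyTorch3D classes or test suite; the sketch checks that indexing keeps the subclass and selects the right batch items for an int, a slice and a list of ints, which is slightly broader than the single case the thread says is enough.

```python
import unittest


# Toy stand-ins (hypothetical): batch items are labels instead of tensors.
class Transform3d:
    def __init__(self, matrix=None):
        self._matrix = ["I"] if matrix is None else list(matrix)

    def get_matrix(self):
        return self._matrix

    def _select(self, index):
        m = self.get_matrix()
        if isinstance(index, int):
            return [m[index]]             # keep the batch dimension
        if isinstance(index, list):
            return [m[i] for i in index]  # list of ints
        return m[index]                   # slice

    def __getitem__(self, index):
        return Transform3d(matrix=self._select(index))


class Rotate(Transform3d):
    def __init__(self, R):
        super().__init__(matrix=list(R))

    def __getitem__(self, index):
        return Rotate(self._select(index))


class TestRotateIndexing(unittest.TestCase):
    def setUp(self):
        self.R = Rotate(["r0", "r1", "r2", "r3"])

    def check(self, index, expected):
        # Indexing must keep the subclass and select the right batch items.
        out = self.R[index]
        self.assertIsInstance(out, Rotate)
        self.assertEqual(out.get_matrix(), expected)

    def test_int(self):
        self.check(0, ["r0"])
        self.check(-1, ["r3"])

    def test_slice(self):
        self.check(slice(1, 3), ["r1", "r2"])

    def test_list(self):
        self.check([0, 2], ["r0", "r2"])
```

Saved to a file, the tests can be run with `python -m unittest`.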

bottler commented 5 months ago

Looks good. We basically just need something. This PR is fine, and will hopefully get merged soon. Thank you!

facebook-github-bot commented 5 months ago

@bottler has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot commented 5 months ago

@bottler merged this pull request in facebookresearch/pytorch3d@b0462d80799543c6ebec06d156a583f42209e9ff.