AlexanderLutsenko / nobuco

PyTorch to Keras/TensorFlow/TFLite conversion made intuitive
MIT License
272 stars, 17 forks

Added __getattribute__ converter #43

Closed: crimson206 closed this 6 months ago

crimson206 commented 6 months ago

Motivation

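`torch.complex` is a function, not a module. Complex tensors, such as those it creates, expose their components as attributes (`.real`, `.imag`), and reading those attributes goes through `torch.Tensor.__getattribute__`. See the example: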

import torch

complex_tensor = torch.complex(torch.randn(4, 4, 4), torch.randn(4, 4, 4))
# Reading `.imag` goes through `torch.Tensor.__getattribute__`.
complex_tensor.imag
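The routing can be reproduced in plain Python, without PyTorch. In the sketch below, `TracedComplex` is a hypothetical stand-in for a tensor class (it is not a PyTorch or nobuco type). It overrides `__getattribute__` to record each public attribute name it is asked for. Reading the `imag` property is recorded as well, which illustrates why a converter targeting `__getattribute__` sees accesses like `complex_tensor.imag`.

```python
class TracedComplex:
    """Hypothetical tensor stand-in that logs public attribute reads."""

    def __init__(self, real: float, imag: float) -> None:
        # Assignments go through __setattr__, so they are not traced.
        self._value = complex(real, imag)
        self._log = []

    def __getattribute__(self, name: str):
        # Every attribute read lands here first; record public names only.
        if not name.startswith("_"):
            object.__getattribute__(self, "_log").append(name)
        return object.__getattribute__(self, name)

    @property
    def real(self) -> float:
        return self._value.real

    @property
    def imag(self) -> float:
        return self._value.imag


z = TracedComplex(1.0, 2.0)
print(z.imag)  # 2.0, read via TracedComplex.__getattribute__
print(z._log)  # ['imag']
```

Names starting with an underscore are skipped so that the class's own bookkeeping (`_value`, `_log`) does not clutter the log.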

Implementation Directory

I created a new file, 'nobuco/node_converters/tensor_getattr.py'. The reason for a dedicated file is explained after the implementation below.

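from typing import Optional

import tensorflow as tf
import torch
from torch import Tensor
from nobuco import ChannelOrderingStrategy, converter

# Converter for attribute access on torch tensors, e.g. complex_tensor.real / .imag
@converter(torch.Tensor.__getattribute__, channel_ordering_strategy=ChannelOrderingStrategy.FORCE_PYTORCH_ORDER)
def converter_getattr(self: Tensor, name: str, *, out: Optional[Tensor] = None):
    def func(self, name: str, *, out: Optional[Tensor] = None):
        # Map each supported attribute name to its TensorFlow equivalent.
        if name == "real":
            return tf.math.real(self)
        elif name == "imag":
            return tf.math.imag(self)
        else:
            # Unsupported attributes are rejected explicitly.
            raise AttributeError(f"'Tensor' object has no attribute '{name}'")
    return func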
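The dispatch logic inside `func` can be tried without TensorFlow or nobuco installed. The sketch below keeps the same `if`/`elif` structure but, purely for illustration, substitutes Python's built-in `complex` for a TF tensor and plain `.real`/`.imag` reads for `tf.math.real`/`tf.math.imag`. `getattr_dispatch` is a hypothetical name, not part of nobuco.

```python
def getattr_dispatch(value: complex, name: str) -> float:
    """Mirror of the converter's `func`, with `complex` standing in for a TF tensor."""
    if name == "real":
        return value.real  # stands in for tf.math.real(self)
    elif name == "imag":
        return value.imag  # stands in for tf.math.imag(self)
    else:
        # Unsupported attributes fail loudly, as in the converter.
        raise AttributeError(f"'Tensor' object has no attribute '{name}'")


print(getattr_dispatch(3 + 4j, "real"))  # 3.0
print(getattr_dispatch(3 + 4j, "imag"))  # 4.0
try:
    getattr_dispatch(3 + 4j, "shape")
except AttributeError as err:
    print(err)  # 'Tensor' object has no attribute 'shape'
```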

There are probably more attributes to convert, and PyTorch may add more in the future. Supporting them does not require new converter functions; we only need to extend the if/elif chain. Because this converter handles an unusual entry point (attribute access rather than a function call), it makes sense to keep it in a dedicated file.
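As an alternative to growing the if/elif chain, the attribute handlers could live in a lookup table, so that supporting a new attribute is a single registration. This is a swapped-in design, not what this PR implements. `ATTR_HANDLERS` and `register_attr` are hypothetical names, and Python's `complex` again stands in for a tensor so the sketch runs without TensorFlow.

```python
from typing import Any, Callable, Dict

# Hypothetical registry mapping an attribute name to the function that computes it.
ATTR_HANDLERS: Dict[str, Callable[[Any], Any]] = {}


def register_attr(name: str) -> Callable[[Callable[[Any], Any]], Callable[[Any], Any]]:
    """Decorator that registers `fn` as the handler for attribute `name`."""
    def decorator(fn: Callable[[Any], Any]) -> Callable[[Any], Any]:
        ATTR_HANDLERS[name] = fn
        return fn
    return decorator


@register_attr("real")
def _real(value: Any) -> Any:
    return value.real  # a TF-based converter would call tf.math.real here


@register_attr("imag")
def _imag(value: Any) -> Any:
    return value.imag  # a TF-based converter would call tf.math.imag here


def getattr_via_table(value: Any, name: str) -> Any:
    """Look up the handler for `name`; unknown names raise AttributeError."""
    handler = ATTR_HANDLERS.get(name)
    if handler is None:
        raise AttributeError(f"'Tensor' object has no attribute '{name}'")
    return handler(value)


print(getattr_via_table(1 - 2j, "imag"))  # -2.0
print(sorted(ATTR_HANDLERS))  # ['imag', 'real']
```

Adding another attribute later would mean writing one more `@register_attr(...)` function, with no change to `getattr_via_table`. The cost is an extra level of indirection, so for two or three attributes the plain if/elif chain is arguably just as clear.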

Implementation Notebook

Please check the Colab notebook.

crimson206 commented 6 months ago

Please revisit the revised Colab notebook to test the implementation.