AlexanderLutsenko / nobuco

PyTorch to Keras/TensorFlow/TFLite conversion made intuitive
MIT License

Stateful model #55

Closed: johan-sightic closed this issue 4 months ago

johan-sightic commented 4 months ago

Hi, is it possible to convert a stateful model like the one below to TFLite? I see that it is possible to convert an RNN, but I don't understand how that works, and I get errors with this model:

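import torch
from torch import nn


class EMAModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear_in = nn.Linear(8, 4)
        self.linear_out = nn.Linear(4, 2)
        self.alpha = 0.8  # smoothing factor of the exponential moving average
        # Persistent state carried between forward() calls; shape (1, 4) means batch size 1
        self.register_buffer("x_ema", torch.ones((1, 4)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear_in(x)

        # In-place EMA update of the buffer: x_ema = alpha * x_ema + (1 - alpha) * x
        self.x_ema *= self.alpha
        self.x_ema += (1.0 - self.alpha) * x

        x = self.linear_out(self.x_ema)
        return x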

(This is just an example model; its forward() method is not used for training.)

Thanks

AlexanderLutsenko commented 4 months ago

Hi! First, thanks for uncovering a major bug in model tracing. Second, yes: since v0.16.0, stateful models can be converted, although not fully automatically. There's more than one way to do it; see the examples: option 1, option 2, option 3.
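The linked examples are not reproduced in this thread. As a hedged, library-independent sketch (not necessarily one of those options), a common way to make a model like `EMAModel` exportable is to move the state out of the module: pass `x_ema` in as an input and return the updated value as a second output, so `forward()` no longer mutates a buffer. The class name `EMAModelExplicitState` is hypothetical, and whether nobuco converts this exact form without further changes is an assumption.

```python
import torch
from torch import nn


class EMAModelExplicitState(nn.Module):
    """Hypothetical stateless variant of EMAModel: the EMA state is an
    explicit input, and the updated state is returned as an explicit output."""

    def __init__(self, alpha: float = 0.8) -> None:
        super().__init__()
        self.linear_in = nn.Linear(8, 4)
        self.linear_out = nn.Linear(4, 2)
        self.alpha = alpha

    def forward(self, x: torch.Tensor, x_ema: torch.Tensor):
        x = self.linear_in(x)
        # Out-of-place update: nothing inside the module is mutated
        x_ema_new = self.alpha * x_ema + (1.0 - self.alpha) * x
        return self.linear_out(x_ema_new), x_ema_new


# The caller owns the state and feeds it back on every step.
torch.manual_seed(0)
model = EMAModelExplicitState().eval()
inputs = [torch.randn(1, 8) for _ in range(3)]
state = torch.ones((1, 4))  # same initial value as the original buffer
with torch.no_grad():
    for x_step in inputs:
        y, state = model(x_step, state)
```

Because `forward()` is now free of side effects, a traced graph sees the state as an ordinary input/output pair, and the code running the converted model is responsible for keeping `state` and passing it back in on the next call. Trained weights from the original `EMAModel` could be copied over with `load_state_dict(..., strict=False)`, which ignores the extra `x_ema` buffer entry.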

johan-sightic commented 4 months ago

Thank you!