Hi, is it possible to convert a stateful model like the one below to TFLite? I see that it is possible to convert an RNN, but I don't understand how that works, and I get errors with this:
```python
import torch
from torch import nn


class EMAModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear_in = nn.Linear(8, 4)
        self.linear_out = nn.Linear(4, 2)
        self.alpha = 0.8
        # EMA state kept in a buffer and mutated on every call
        self.register_buffer("x_ema", torch.ones((1, 4)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear_in(x)
        # In-place exponential moving average update of the buffer
        self.x_ema *= self.alpha
        self.x_ema += (1.0 - self.alpha) * x
        x = self.linear_out(self.x_ema)
        return x
```
(This is just an example model, and that `forward()` method is not used for training.)
Hi!
First, thanks for discovering a major bug in model tracing.
Second, yes. Since v0.16.0, stateful models can be converted, although not fully automatically. There's more than one way to do it; take a look at the examples: option 1, option 2, option 3.
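One general way to make a model like `EMAModel` convertible, sketched here as an illustration and not necessarily what any of the linked options does, is to lift the state out of the module: `x_ema` becomes an extra input and an extra output, and the caller carries it between calls. The converted graph is then a pure function with no in-place buffer updates. `EMAModelExplicitState` and `run_stream` are hypothetical names, and the actual export/conversion call is left out on purpose.

```python
import torch
from torch import nn


class EMAModelExplicitState(nn.Module):
    """Hypothetical stateless variant of EMAModel: the EMA state is passed in
    and returned instead of living in a mutable buffer."""

    def __init__(self, alpha: float = 0.8) -> None:
        super().__init__()
        self.linear_in = nn.Linear(8, 4)
        self.linear_out = nn.Linear(4, 2)
        self.alpha = alpha

    def forward(self, x: torch.Tensor, x_ema: torch.Tensor):
        x = self.linear_in(x)
        # Out-of-place update: a new tensor is produced, nothing is mutated
        new_ema = self.alpha * x_ema + (1.0 - self.alpha) * x
        return self.linear_out(new_ema), new_ema


def run_stream(model: nn.Module, inputs, init_state: torch.Tensor):
    """Caller-side loop: feeds each step's output state into the next call."""
    state = init_state
    outputs = []
    for x in inputs:
        y, state = model(x, state)
        outputs.append(y)
    return outputs, state
```

With `init_state = torch.ones(1, 4)`, this reproduces the buffer's initial value in the original model. The trade-off is that the application running the TFLite model must store `x_ema` between invocations and pass it back in.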