google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.

Error when exporting a model that uses `torch.sum()` #268

Status: Open (opened by rishi-menon 2 months ago)

rishi-menon commented 2 months ago

Description of the bug:

When I use torch.sum() without specifying a dimension, the model is not exported correctly: the converted model skips the sum. It is exported correctly if I do specify the dimension.

Actual vs expected behavior:

Minimum working example of the bug:

Code

import os
import torch
import torch.nn as nn
import ai_edge_torch

class SampleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = nn.Parameter(torch.tensor([2.0]), requires_grad=True)

    def forward(self, x):
        return (self.weights * x).sum()

if __name__ == "__main__":
    os.environ["PJRT_DEVICE"] = "CPU"

    model = SampleModel().eval()
    input_tensor = torch.tensor([1.0, 2.0, 3.0])
    expected_output = model(input_tensor)

    edge_model = ai_edge_torch.convert(model.eval(), (input_tensor,))
    model_output = edge_model(input_tensor)

    print("")
    print(f"Expected output: {expected_output}")
    print(f"Model output   : {model_output}")
    print("")

    edge_model.export("model.tflite")

Output:

Expected output: 12.0
Model output   : [2. 4. 6.]

The expected output is 12.0, but the converted model returns [2. 4. 6.]. The converted model performs only the multiplication and omits the sum reduction.
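A mismatch like this can be caught automatically by comparing shapes before values. The following is a minimal sketch, not part of ai_edge_torch: `outputs_match` is a hypothetical helper name, and it assumes both outputs are already array-like (a PyTorch tensor that requires grad would first need `.detach().numpy()`).

```python
import numpy as np

def outputs_match(expected, actual, atol=1e-6):
    """Return True only if both outputs have the same shape and close values."""
    expected = np.asarray(expected, dtype=np.float64)
    actual = np.asarray(actual, dtype=np.float64)
    # A dropped reduction shows up as a shape mismatch: () versus (3,).
    if expected.shape != actual.shape:
        return False
    return bool(np.allclose(expected, actual, atol=atol))

# Values taken from the output above.
print(outputs_match(12.0, [2.0, 4.0, 6.0]))  # False
print(outputs_match(12.0, 12.0))             # True
```

Checking the shape first matters here because NumPy broadcasting would otherwise compare the scalar against each element instead of flagging the missing reduction.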

Model Visualized

The model looks like this when exported: [image]

Any other information you'd like to share?

Minimum working example when I specify a dimension:

When I specify the dimension in the torch.sum() call, the model is exported correctly.

Code

import os
import torch
import torch.nn as nn
import ai_edge_torch

class SampleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = nn.Parameter(torch.tensor([2.0]), requires_grad=True)

    def forward(self, x):
        return (self.weights * x).sum(dim=0)

if __name__ == "__main__":
    os.environ["PJRT_DEVICE"] = "CPU"

    model = SampleModel().eval()
    input_tensor = torch.tensor([1.0, 2.0, 3.0])
    expected_output = model(input_tensor)

    edge_model = ai_edge_torch.convert(model.eval(), (input_tensor,))
    model_output = edge_model(input_tensor)

    print("")
    print(f"Expected output: {expected_output}")
    print(f"Model output   : {model_output}")
    print("")

    edge_model.export("model.tflite")

Output:

Expected output: 12.0
Model output   : 12.0

Model Visualized

The model looks like this when exported: [image]
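Based on the observation that an explicit dimension converts correctly, one possible workaround is to always name the reduced axes. This is only a sketch: `sum_all` is a hypothetical helper, and the thread confirms the behaviour only for `dim=0` on a 1-D input, so whether the converter handles a tuple of dimensions is an untested assumption.

```python
import torch

def sum_all(x: torch.Tensor) -> torch.Tensor:
    """Sum over every axis by listing the axes explicitly."""
    # For a 1-D tensor this is equivalent to x.sum(dim=(0,)).
    return x.sum(dim=tuple(range(x.dim())))

# In eager PyTorch this matches x.sum() for the input used above.
x = torch.tensor([2.0, 4.0, 6.0])
print(sum_all(x).item())  # 12.0
```

In the model above, `forward` would then return `sum_all(self.weights * x)` instead of `(self.weights * x).sum()`.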

pkgoogle commented 2 months ago

I was able to reproduce exactly as above.