google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0

Unable to Convert _sa_block Method of torch.nn.TransformerEncoderLayer #19

Open bisnu-sarkar-inverseai opened 1 month ago

bisnu-sarkar-inverseai commented 1 month ago

Description of the bug:

I'm getting an error when I try to convert my PyTorch model to TFLite with the ai-edge-torch library. The error occurs when my subclass calls the _sa_block method it inherits from torch.nn.TransformerEncoderLayer. Below is the part of my model that causes the issue.

import torch
import torch.nn as nn
import torch.nn.functional as F

# MyGroupNorm and LayerScale are modules defined elsewhere in the model (not shown here).
class MyTransformerEncoderLayer(nn.TransformerEncoderLayer):
    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=F.relu,
        group_norm=0,
        norm_first=False,
        norm_out=False,
        layer_norm_eps=1e-5,
        layer_scale=False,
        init_values=1e-4,
        device=None,
        dtype=None,
        sparse=False,
        mask_type="diag",
        mask_random_seed=42,
        sparse_attn_window=500,
        global_window=50,
        auto_sparsity=False,
        sparsity=0.95,
        batch_first=False,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            batch_first=batch_first,
            norm_first=norm_first,
            device=device,
            dtype=dtype,
        )

        if group_norm:
            self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)

        self.norm_out = None
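        # Use `and`, not `&`: with an int norm_out, True & 4 == 0 and norm_out would be silently skipped.
        if self.norm_first and norm_out: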
            self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
        self.gamma_1 = (
            LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        )
        self.gamma_2 = (
            LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        )

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """
        If batch_first=False, src has shape (T, B, C).
        The batch_first=True case is not covered.
        """
        device = src.device
        x = src
        T, B, C = x.shape
        if self.norm_first:
            x = self.norm1(x)
            x = x + self.gamma_1(
                self._sa_block(x, src_mask, src_key_padding_mask)
            )
            x = x + self.gamma_2(self._ff_block(self.norm2(x)))

            if self.norm_out is not None:
                x = self.norm_out(x)
        else:
            x = self.norm1(
                x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask))
            )
            x = self.norm2(x + self.gamma_2(self._ff_block(x)))

        return x

Actual vs expected behavior:

Calling the inherited _sa_block method of torch.nn.TransformerEncoderLayer fails with: torch._dynamo.exc.Unsupported: call_method NNModuleVariable() _sa_block [TensorVariable(), LazyVariableTracker(), LazyVariableTracker()] {}

Any other information you'd like to share?

Thank you for your attention to this matter. I look forward to your response and any guidance you can provide.

advaitjain commented 1 month ago

Please see https://github.com/google-ai-edge/ai-edge-torch/blob/main/docs/pytorch_converter/README.md#debugging--reporting-errors for tips on debugging conversion errors.

In this case you appear to be hitting an error during torch.export, so the model source needs to be modified.

bisnu-sarkar-inverseai commented 1 month ago

Thanks for your reply. The error does occur during torch.export.export().

talumbau commented 6 days ago

Hi, just to follow up here. Are you certain you need to inherit from TransformerEncoderLayer? That is, do you intend to use NestedTensors? If not, you might be able to accomplish your goal with a more "manual" approach, inheriting from nn.Module instead.

For creating an encoder, please see examples/t5/t5.py, which shows how to author and convert a transformer-based encoder. It has T5-specific pieces, but you should be able to strip it down to something that is exactly like a BERT encoder.