sirius-ai / LPRNet_Pytorch

Pytorch Implementation For LPRNet, A High Performance And Lightweight License Plate Recognition Framework.
Apache License 2.0

Issue about adding STN #95

Open Powerfulidot opened 1 month ago

Powerfulidot commented 1 month ago

I added an STN (spatial transformer network) to the top of the LPRNet architecture, but this new structure seems very hard to train and the loss doesn't go down. Does anybody know how to deal with this?

risangbaskoro commented 1 month ago

My approach is to create a separate class that extends nn.Module and works as the STN.

The purpose of this module is to output the grid_sample result, i.e. the spatially transformed input, after the backbone has learned for a few epochs.

In my implementation, I create two classes: LocNet, which predicts the affine matrices, and SpatialTransformerLayer, which applies them to the input.

My implementation is as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F


class LocNet(nn.Module):
    """LocNet architecture for Spatial Transformer Layer"""

    def __init__(self):
        super().__init__()

        # self.avg_pool = nn.AvgPool2d(kernel_size=(3, 3), stride=(2, 2))
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(24, 94))
        self.conv_l = nn.Conv2d(
            in_channels=3, out_channels=32, kernel_size=(5, 5), stride=(3, 3)
        )
        self.conv_r = nn.Conv2d(
            in_channels=3, out_channels=32, kernel_size=(5, 5), stride=(3, 3)
        )

        self.dropout = nn.Dropout2d()

        self.fc_1 = nn.Linear(in_features=64 * 7 * 30, out_features=32)
        self.fc_2 = nn.Linear(in_features=32, out_features=6)

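        # Zero the weights and set the bias to the identity affine parameters,
        # so the first prediction does not depend on the input.
        self.fc_2.weight.data.zero_()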
        self.fc_2.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input (torch.Tensor): Tensor of shape (N, C, H, W), where:
                - N: batch size
                - C: channel
                - H: height (pixel) of the image
                - W: width (pixel) of the image
        Return:
            torch.Tensor of affine matrices shape (N, 2, 3).
        """
        x_l = self.avg_pool(input)
        x_l = self.conv_l(x_l)

        x_r = self.conv_r(input)

        xs = torch.cat([x_l, x_r], dim=1)
        xs = self.dropout(xs)

        xs = xs.flatten(start_dim=1)  # Flatten for fully-connected layer
        xs = self.fc_1(xs)
        xs = torch.tanh(xs)  # activation
        xs = self.fc_2(xs)
        xs = torch.tanh(xs)  # activation
        theta = xs.view(-1, 2, 3)  # transform the shape to (N, 2, 3)
        return theta


class SpatialTransformerLayer(nn.Module):
    """Spatial Transformer Layer module

    Args:
        localization (torch.nn.Module): Module to generate localization.
        align_corners (bool):
            Whether to align_corners. See https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
              and https://pytorch.org/docs/stable/generated/torch.nn.functional.affine_grid.html
    """

    def __init__(self, localization: nn.Module, align_corners: bool = False):
        super().__init__()
        self.localization = localization
        self.align_corners = align_corners

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input (torch.Tensor): Tensor of shape (N, C, H, W), where:
                - N: batch size
                - C: channel
                - H: height (pixel) of the image
                - W: width (pixel) of the image

        Return:
            torch.Tensor: the input resampled on the affine grid, same shape as the input.
        """
        theta = self.localization(input)
        grid = F.affine_grid(
            theta=theta, size=input.shape, align_corners=self.align_corners
        )
        return F.grid_sample(input, grid=grid, align_corners=self.align_corners)
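
One thing worth checking with the LocNet above, offered as a hedged observation rather than a diagnosis: fc_2 is initialized to the identity parameters [1, 0, 0, 0, 1, 0], but the final torch.tanh maps 1 to roughly 0.76, so the first predicted transform is a slight zoom rather than an exact identity. The minimal sketch below uses only F.affine_grid and F.grid_sample on a random tensor to compare the two cases. The helper warp and all variable names are illustrative assumptions, not part of the implementation above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.rand(2, 3, 24, 94)  # random stand-in for a batch of plate images

# Exact identity affine matrices, shape (N, 2, 3).
identity = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]).repeat(2, 1, 1)


def warp(images: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    """Resample images with the given affine matrices (hypothetical helper)."""
    grid = F.affine_grid(theta, size=list(images.shape), align_corners=False)
    return F.grid_sample(images, grid, align_corners=False)


# With the exact identity, the output matches the input up to float error.
identity_error = (warp(x, identity) - x).abs().max().item()

# After tanh, the diagonal becomes tanh(1), about 0.76: a zoom, not an identity.
squashed = torch.tanh(identity)
squashed_error = (warp(x, squashed) - x).abs().max().item()

print(f"diagonal after tanh: {squashed[0, 0, 0].item():.4f}")
print(f"max error with identity theta: {identity_error:.2e}")
print(f"max error with tanh(identity) theta: {squashed_error:.2e}")
```

If an exact identity start matters for your training, one option is to remove or rescale the final tanh. Whether that helps with a stalled loss is not established in this thread.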

Then you can add the layer to your LPRNet model class, roughly as follows:

class LPRNet(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        # Initialization steps
        localization = LocNet()
        self.stn = SpatialTransformerLayer(localization=localization)
        # Other initialization steps...

    def forward(self, x):
        xs = self.stn(x)
        # do the rest with the backbone and global context
        return xs

Keep in mind that in the original paper, the STN is initially turned off and only enabled at 5k epochs. You might want to add a property and some simple conditional logic to your model class to handle this.
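
A minimal sketch of that conditional logic, assuming a boolean flag that the training loop flips after a warm-up period. The names GatedSTN, enabled, Flip, and warmup_steps are hypothetical and not part of LPRNet_Pytorch. Flip is only a stand-in for the real STN so that the example runs by itself and the switch is easy to observe.

```python
import torch
import torch.nn as nn


class Flip(nn.Module):
    """Stand-in for the STN: flips the width axis so its effect is visible."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.flip(x, dims=[-1])


class GatedSTN(nn.Module):
    """Hypothetical wrapper that bypasses the wrapped module until enabled."""

    def __init__(self, stn: nn.Module, enabled: bool = False):
        super().__init__()
        self.stn = stn
        self.enabled = enabled

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # While disabled, the input is passed on unchanged.
        return self.stn(x) if self.enabled else x


gated = GatedSTN(Flip())
x = torch.rand(2, 3, 24, 94)
warmup_steps = 3  # illustrative value only

passed_through = []
for step in range(5):
    if step == warmup_steps:
        gated.enabled = True  # switch the STN on after the warm-up
    out = gated(x)
    passed_through.append(torch.equal(out, x))

print(passed_through)
```

In a real model, the same flag could live directly on the LPRNet class, with the training script setting it once the chosen warm-up threshold is reached.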

If you train your model long enough, the result may look similar to the following (the first four rows are the inputs after augmentation, the rest are the same inputs transformed by the model):

[Image: grid_combined 2]

Note: Reference: https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html

Powerfulidot commented 1 month ago

@risangbaskoro Thank you! I'll give it a try!

zjykzj commented 1 month ago

@Powerfulidot Judging from my training results, adding STNet should significantly improve LPRNet's license plate recognition performance, but you need to pay attention to the model training process. See #96 for more information.

| Model | ARCH | Input Shape | GFLOPs | Model Size (MB) | ChineseLicensePlate Accuracy (%) | Training Data | Testing Data |
|---|---|---|---|---|---|---|---|
| LPRNet | CONV | (3, 24, 94) | 0.3 | 1.9 | 60.105 | 269,621 | 149,002 |
| LPRNet+STNet | CONV | (3, 24, 94) | 0.3 | 2.2 | 72.261 | 269,621 | 149,002 |

Powerfulidot commented 1 month ago

@risangbaskoro @zjykzj Well, thank you both for your help, but after applying your methods the issue remains. The loss either refuses to go down or drops really slowly. On the other hand, the original LPRNet structure trains without any problem using the same training parameters.