Powerfulidot opened this issue 1 month ago
My implementation is to create another class that extends `nn.Module`. This class works as the STN.
The purpose of this module is to output the result of `grid_sample`, i.e. the input warped by an affine transformation whose parameters are learned together with the backbone during training.
In my implementation, I create two classes:
- `LocNet`, which is the localization module.
- `SpatialTransformerLayer` (the STN), which takes any localization module (in my case, `LocNet`) when initialized and, in its forward pass, returns the grid-sampled input, thus applying an affine transformation to it.

The following are my implementations:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LocNet(nn.Module):
    """LocNet architecture for Spatial Transformer Layer"""

    def __init__(self):
        super().__init__()
        # self.avg_pool = nn.AvgPool2d(kernel_size=(3, 3), stride=(2, 2))
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(24, 94))
        self.conv_l = nn.Conv2d(
            in_channels=3, out_channels=32, kernel_size=(5, 5), stride=(3, 3)
        )
        self.conv_r = nn.Conv2d(
            in_channels=3, out_channels=32, kernel_size=(5, 5), stride=(3, 3)
        )
        self.dropout = nn.Dropout2d()
        # 64 channels (32 + 32) x 7 x 30 positions for a 24x94 input
        self.fc_1 = nn.Linear(in_features=64 * 7 * 30, out_features=32)
        self.fc_2 = nn.Linear(in_features=32, out_features=6)
        # Initialize fc_2 so that, before the final tanh, it outputs the identity matrix
        nn.init.zeros_(self.fc_2.weight)
        with torch.no_grad():
            self.fc_2.bias.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input (torch.Tensor): Tensor of shape (N, C, H, W), where:
                - N: batch size
                - C: number of channels
                - H: height (pixels) of the image
                - W: width (pixels) of the image

        Returns:
            torch.Tensor of affine matrices with shape (N, 2, 3).
        """
        x_l = self.avg_pool(input)
        x_l = self.conv_l(x_l)
        x_r = self.conv_r(input)
        xs = torch.cat([x_l, x_r], dim=1)  # concatenate both branches channel-wise
        xs = self.dropout(xs)
        xs = xs.flatten(start_dim=1)  # Flatten for fully-connected layer
        xs = self.fc_1(xs)
        xs = torch.tanh(xs)  # activation
        xs = self.fc_2(xs)
        # At initialization tanh(1) ≈ 0.76, so theta starts at about 0.76 * identity
        xs = torch.tanh(xs)  # activation
        theta = xs.view(-1, 2, 3)  # transform the shape to (N, 2, 3)
        return theta
```
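The `in_features=64 * 7 * 30` of `fc_1` follows from the usual convolution output-size formula. A quick pure-Python check, assuming a 3x24x94 input (LPRNet's input shape) and dilation 1; `conv_out` is a hypothetical helper, not part of the code above:

```python
# Sanity check of LocNet's layer sizes for a 24x94 input.
# Pure-Python arithmetic; nothing here calls PyTorch.


def conv_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Output length of a Conv2d along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


H, W = 24, 94

# AdaptiveAvgPool2d((24, 94)) leaves a 24x94 input unchanged, so both
# branches see the same spatial size before their 5x5, stride-3 convs.
h = conv_out(H, kernel=5, stride=3)
w = conv_out(W, kernel=5, stride=3)

channels = 32 + 32  # conv_l and conv_r outputs concatenated on dim=1
fc_1_in = channels * h * w

print(h, w, fc_1_in)  # 7 30 13440
```

Only the left branch is resized by `AdaptiveAvgPool2d`, so for most other input sizes `conv_r` would no longer produce a 7x30 map and `torch.cat` would fail with a shape mismatch.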
```python
class SpatialTransformerLayer(nn.Module):
    """Spatial Transformer Layer module

    Args:
        localization (torch.nn.Module): Module that predicts the affine
            matrices theta, of shape (N, 2, 3), from the input.
        align_corners (bool):
            Whether to align corners. See https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
            and https://pytorch.org/docs/stable/generated/torch.nn.functional.affine_grid.html
    """

    def __init__(self, localization: nn.Module, align_corners: bool = False):
        super().__init__()
        self.localization = localization
        self.align_corners = align_corners

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input (torch.Tensor): Tensor of shape (N, C, H, W), where:
                - N: batch size
                - C: number of channels
                - H: height (pixels) of the image
                - W: width (pixels) of the image

        Returns:
            torch.Tensor: the input resampled by the predicted affine
            transformation, with the same shape (N, C, H, W).
        """
        theta = self.localization(input)  # (N, 2, 3)
        # Sampling grid of shape (N, H, W, 2) in normalized [-1, 1] coordinates
        grid = F.affine_grid(
            theta=theta, size=input.shape, align_corners=self.align_corners
        )
        return F.grid_sample(input, grid=grid, align_corners=self.align_corners)
```
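To make concrete what `SpatialTransformerLayer.forward` does per pixel, here is a pure-Python sketch of the coordinate mapping that `F.affine_grid` followed by `F.grid_sample` applies with `align_corners=False` (the default above). It assumes PyTorch's documented convention that -1 and 1 denote the outer edges of the corner pixels; the helpers are hypothetical, trace a single pixel, and do no interpolation:

```python
import math

# Sampling coordinates produced by F.affine_grid(theta, size, align_corners=False)
# and consumed by F.grid_sample(..., align_corners=False). Helper names are hypothetical.


def normalize(index: int, size: int) -> float:
    """Pixel index -> normalized coordinate of that pixel's center."""
    return (2 * index + 1) / size - 1


def unnormalize(coord: float, size: int) -> float:
    """Normalized coordinate -> fractional pixel index (inverse of normalize)."""
    return ((coord + 1) * size - 1) / 2


def source_pixel(theta, row, col, height, width):
    """Fractional (row, col) in the input that output pixel (row, col) samples from."""
    x_t = normalize(col, width)   # x runs along the width
    y_t = normalize(row, height)  # y runs along the height
    x_s = theta[0][0] * x_t + theta[0][1] * y_t + theta[0][2]
    y_s = theta[1][0] * x_t + theta[1][1] * y_t + theta[1][2]
    return unnormalize(y_s, height), unnormalize(x_s, width)


H, W = 24, 94
identity = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]

# LocNet's fc_2 starts with zero weights and an identity bias, but the bias
# passes through tanh, so the initial theta is tanh(1) * identity.
s = math.tanh(1.0)
initial = [[s, 0.0, 0.0], [0.0, s, 0.0]]

print(source_pixel(identity, 0, 0, H, W))  # approximately (0.0, 0.0)
print(source_pixel(initial, 0, 0, H, W))   # approximately (2.74, 11.09)
```

With the identity theta every output pixel samples from itself, while the tanh-squashed initial theta makes the top-left output pixel sample from around row 2.7, column 11.1: the layer starts out zoomed into the central ~76% of the plate. If an exact identity start is wanted, one option that may be worth trying is dropping the final `torch.tanh`, as the localization head in the referenced PyTorch tutorial does.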
Then you can add the STN to LPRNet as a submodule (composition rather than subclassing), roughly as follows:
```python
class LPRNet(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        # Initialization steps
        localization = LocNet()
        self.stn = SpatialTransformerLayer(localization=localization)
        # Other initialization steps...

    def forward(self, x):
        xs = self.stn(x)
        # do the rest with the backbone and global context
        return xs
```
Keep in mind that in the original paper, the STN is initially turned off and only enabled after 5k epochs. You might want to add a simple flag and conditional logic to your model class for this.
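A minimal sketch of such a switch, assuming the STN is wrapped so that it stays a registered submodule while disabled. `GatedSTN`, `set_epoch` and `stn_start_epoch` are hypothetical names, and the 5000 default only mirrors the number quoted above:

```python
import torch
import torch.nn as nn


class GatedSTN(nn.Module):
    """Hypothetical wrapper that applies `stn` only once it has been enabled.

    The wrapped STN stays a registered submodule, so its parameters are part of
    model.parameters() and state_dict() even while it is switched off.
    """

    def __init__(self, stn: nn.Module, stn_start_epoch: int = 5000):
        super().__init__()
        self.stn = stn
        self.stn_start_epoch = stn_start_epoch  # assumption: switch on an epoch count
        self.stn_enabled = False  # plain attribute, not saved in state_dict

    def set_epoch(self, epoch: int) -> None:
        """Enable the STN once training reaches `stn_start_epoch`."""
        self.stn_enabled = epoch >= self.stn_start_epoch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # While disabled, the input passes through unchanged and the STN's
        # parameters receive no gradients.
        return self.stn(x) if self.stn_enabled else x
```

In `LPRNet.__init__` one could then write `self.stn = GatedSTN(SpatialTransformerLayer(localization=LocNet()))` and call `model.stn.set_epoch(epoch)` at the start of each epoch. Because the flag is a plain attribute rather than a buffer, it is not stored in `state_dict` and has to be set again after loading a checkpoint.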
If you train your model long enough, the result may look similar to the following (the first four rows are the inputs after augmentation; the remaining rows are the outputs transformed by the model):
> [!NOTE]
> Reference: https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html
@risangbaskoro Thank you! I'll give it a try!
I added an STN (spatial transformer network) at the top (input end) of LPRNet, but this new structure seems quite hard to train: the loss doesn't go down. Does anyone know how to deal with it?
@Powerfulidot Based on the training results below, adding STNet should significantly improve LPRNet's license plate recognition accuracy (72.261% vs. 60.105%), but the training process needs careful attention. See #96 for more info.
Model | ARCH | Input Shape | GFLOPs | Model Size (MB) | ChineseLicensePlate Accuracy (%) | Training Data | Testing Data |
---|---|---|---|---|---|---|---|
LPRNet | CONV | (3, 24, 94) | 0.3 | 1.9 | 60.105 | 269,621 | 149,002 |
LPRNet+STNet | CONV | (3, 24, 94) | 0.3 | 2.2 | 72.261 | 269,621 | 149,002 |
@risangbaskoro @zjykzj Thank you both for your help, but the issue remains after applying your methods: the loss either refuses to go down or drops very slowly. On the other hand, the original LPRNet structure trains without any problem using the same training parameters.