MhLiao/DB

A PyTorch implementation of "Real-time Scene Text Detection with Differentiable Binarization".

How to create lightweight code for a working model? #265

Closed · mattroos closed this 3 years ago

mattroos commented 3 years ago

This is a general question, not an issue with the code.

I'm trying to write code that straightforwardly builds a model (one with a MobileNetv3 backbone that I trained on ICDAR 2015), gives the same results as demo.py, and can (hopefully) be used to convert the model to ONNX. But the code in this repo is so deeply abstracted into a multitude of classes that I can't figure out how to do it. I'm very close, but...

When I run demo.py on a training image from ICDAR 2015, I get the prediction below, which looks good.

[screenshot: correct detection output from demo.py]

But when I write code that does nothing more than build the MobileNetv3 and decoder models and load the weights, I get the result below. I think the features from the backbone are fine, but the decoder output is a checkered pattern at single-pixel resolution (the larger-scale checkered features in the image are not "real" but are due to image aliasing).

[screenshot: checkered decoder output from my standalone code]

It seems like it should be straightforward, but I'm obviously missing something. Any guesses? @MhLiao? The output looks like what one might get if the decoder weights were not loaded properly (or not loaded at all), but I believe I'm doing it correctly. See the ## Build the model section in my code.

Here is my code:

import numpy as np
import torch
import torch.nn as nn
import cv2
from backbones.mobilenetv3 import mobilenet_v3_large
from decoders.seg_detector import SegDetector
import matplotlib.pyplot as plt
plt.ion()

device = torch.device('cuda')
# device = torch.device('cpu')

RGB_MEAN = np.array([122.67891434, 116.66876762, 104.00698793])
path_img = './datasets/icdar2015/train_images/img_964.jpg'
path_model = './outputs/workspace/DB/SegDetectorModel-seg_detector/mobilenet_v3_large/L1BalanceCELoss/model/model_epoch_936_minibatch_117000'
input_shape = (736, 1312)
# input_shape = (1056, 1888)

class BasicModel(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        self.backbone = mobilenet_v3_large(pretrained=False)
        self.decoder = SegDetector(in_channels=[24, 40, 112, 960], inner_channels=256, k=50,
                                    bias=False, adaptive=True, smooth=False, serial=False)

    def forward(self, data, *args, **kwargs):
        return self.decoder(self.backbone(data), *args, **kwargs)

## Build the model
model = BasicModel()
model = model.to(device)
states = torch.load(path_model, map_location=device)
model.load_state_dict(states, strict=False)  # note: strict=False silently skips any keys that don't match
model.eval()
# model = nn.DataParallel(model)

## Get the image tensor, for testing
img_original = cv2.imread(path_img, cv2.IMREAD_COLOR).astype('float32')
img_tensor = cv2.resize(img_original, input_shape[::-1])
img_tensor = img_tensor - RGB_MEAN
img_tensor /= 255
img_tensor = torch.from_numpy(img_tensor).permute(2, 0, 1).float().unsqueeze(0)
img_tensor = img_tensor.to(device)

## Run the image through the model
with torch.no_grad():
    # pred = model.forward(img_tensor, training=False)
    pred_backbone = model.backbone(img_tensor)  # The output of this is a tuple of 4 tensors, all of which look good
    pred_decoder = model.decoder(pred_backbone, training=False)

## Plot results
plt.figure(1)

plt.subplot(1, 2, 1)
plt.imshow(pred_backbone[0].cpu().numpy()[0, -1, :, :])
plt.axis('off')
plt.title('Example backbone output')

plt.subplot(1, 2, 2)
plt.imshow(pred_decoder.cpu().numpy()[0, -1, :, :])
plt.axis('off')
plt.title('Example decoder output')
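
For the ONNX goal mentioned at the top, here is a minimal export sketch, assuming the weights actually load into this standalone model; the file name, tensor names, and opset are placeholders rather than anything this repo prescribes:

# Hypothetical export call: a fixed-size dummy input matching the
# preprocessing above. Whether the adaptive-threshold branch exports
# cleanly may depend on the opset.
dummy = torch.randn(1, 3, 736, 1312, device=device)
torch.onnx.export(model, dummy, 'db_mobilenetv3.onnx',
                  input_names=['image'], output_names=['binary_map'],
                  opset_version=11)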

mattroos commented 3 years ago

For now, I resolved this by importing the SegDetectorModel class to build the model rather than using the BasicModel class directly. I still want to hone this down to a minimal code footprint, but I don't quite understand what's going on in SegDetectorModel.__init__(), particularly in the parallelize() call, without which the model doesn't perform as expected.
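
A plausible mechanism, sketched under the assumption that the checkpoint was saved from SegDetectorModel: that class keeps the network in self.model and parallelize() wraps it in nn.DataParallel, so every saved key would carry a model.module. prefix. A bare BasicModel then matches none of them, and strict=False hides that completely, leaving the decoder at its random init. The prefix below is that assumption, not something confirmed in this thread:

states = torch.load(path_model, map_location=device)

# load_state_dict returns the mismatches; with strict=False this is the
# only signal that nothing matched, so inspect it.
result = model.load_state_dict(states, strict=False)
print('missing keys:', len(result.missing_keys))
print('unexpected keys:', len(result.unexpected_keys))

# If essentially everything is reported, strip the assumed wrapper prefix
# and retry with strict=True so any remaining mismatch raises.
prefix = 'model.module.'
stripped = {k[len(prefix):]: v for k, v in states.items() if k.startswith(prefix)}
model.load_state_dict(stripped, strict=True)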

kurbobo commented 2 years ago

> For now, I resolved this by importing the SegDetectorModel class to build the model rather than using the BasicModel class directly. I still want to hone this down to a minimal code footprint, but I don't quite understand what's going on in SegDetectorModel.__init__(), particularly in the parallelize() call, without which the model doesn't perform as expected.

Hello! Did you work out what the strange behaviour with parallelize() is? I've encountered a similar problem...
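
A sketch of what parallelize() most likely changes (a standalone experiment, not code from this repo): nn.DataParallel stores the wrapped network under .module, so every state-dict key gains a module. prefix, and a checkpoint saved through it only lines up with a model wrapped the same way:

import torch.nn as nn

net = nn.Linear(2, 2)
print(list(net.state_dict().keys()))
# -> ['weight', 'bias']
print(list(nn.DataParallel(net).state_dict().keys()))
# -> ['module.weight', 'module.bias']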