DeokyunKim / Progressive-Face-Super-Resolution

Official PyTorch Implementation of Progressive Face Super-Resolution (BMVC 2019 Accepted)

I wrote a demo for anyone who wants to load the compressed FAN #21

Open ALLinLLM opened 3 years ago

ALLinLLM commented 3 years ago

First, check that checkpoints/compressed_model_011000.pth exists in your working directory. Then create an empty folder named fan_model and add a new Python file, fan_model/__init__.py, with the following content:

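# fan_model/__init__.py
import torch.nn as nn


class fan_squeeze(nn.Module):
    # The module path and class name (fan_model.fan_squeeze) must match the ones
    # recorded in the checkpoint, because torch.load looks the class up by name.
    # When the full model is unpickled, __init__ is NOT run: the saved `conv`
    # and `layers` attributes are restored directly, so forward() must use
    # exactly these attribute names.
    def __init__(self):
        super(fan_squeeze, self).__init__()
        # 5x5 conv, stride 1, padding 2: RGB input -> 32 channels, same H x W
        self.conv = nn.Conv2d(3, 32, 5, 1, 2)
        # 3x3 convs with stride 1 and padding 1 keep H x W;
        # the last conv outputs 68 heatmap channels
        self.layers = nn.Sequential(
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.Conv2d(128, 68, 3, 1, 1),
            nn.LeakyReLU(),
            )

    def forward(self, x):
        out = self.conv(x)      # (N, 32, H, W)
        out = self.layers(out)  # (N, 68, H, W)
        return out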
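As a sanity check on the output shape you should expect: every convolution in fan_squeeze uses stride 1 with padding (k - 1) / 2, so the spatial size is preserved and only the channel count changes, ending at 68 channels (presumably one heatmap per landmark of the 68-point scheme). The sketch below applies the standard convolution output-size formula, out = floor((in + 2p - k) / s) + 1 (dilation 1), to the conv layers copied from the class above. The helper names are hypothetical, and 128 x 128 is just an example size, not necessarily what the repository uses.

```python
# Hypothetical helpers: trace tensor shapes through the fan_squeeze layers.
# BatchNorm2d and LeakyReLU do not change shapes, so only the convs matter.

# (in_channels, out_channels, kernel, stride, padding) for each Conv2d, in order
CONV_LAYERS = [
    (3, 32, 5, 1, 2),
    (32, 64, 3, 1, 1),
    (64, 128, 3, 1, 1),
    (128, 128, 3, 1, 1),
    (128, 68, 3, 1, 1),
]


def conv_out_size(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a Conv2d with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_shape(n: int, c: int, h: int, w: int) -> tuple:
    """Return the (N, C, H, W) shape after all conv layers."""
    for in_c, out_c, k, s, p in CONV_LAYERS:
        if c != in_c:
            raise ValueError(f"expected {in_c} input channels, got {c}")
        c = out_c
        h = conv_out_size(h, k, s, p)
        w = conv_out_size(w, k, s, p)
    return (n, c, h, w)


print(trace_shape(1, 3, 128, 128))  # (1, 68, 128, 128)
```

So for any input of shape (1, 3, H, W), the printed heatmap shape should be (1, 68, H, W).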

Next, add a new Python file directly in your working directory:

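# load_compressed_FAN.py
import torch
from PIL import Image
import torchvision.transforms as transforms

# use the GPU if there is one, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# prepare input image: ToTensor scales pixels to [0, 1], Normalize maps them to [-1, 1]
image = Image.open("./figure/eval_target_image.jpeg").convert('RGB')
totensor = transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                            ])
x = totensor(image).unsqueeze(0).to(device)  # shape (1, 3, H, W)

# load the saved model with both structure and weights
# NOTICE: ./fan_model/__init__.py must define class fan_squeeze, otherwise
#         torch.load fails. No explicit import is needed: pickle stored the
#         class as fan_model.fan_squeeze and imports that module by name while
#         loading, which works because fan_model/ sits next to this script.
# weights_only=False is required on PyTorch >= 2.6, where the default became
# True and full pickled models are rejected (the argument exists since 1.13).
# Only use it with checkpoint files you trust.
model = torch.load('./checkpoints/compressed_model_011000.pth',
                   map_location=device, weights_only=False)
model = model.to(device)
model.eval()

# run forward
with torch.no_grad():
    print("x shape:", x.shape)
    heatmap_68 = model(x)
    print("heatmap shape:", heatmap_68.shape)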
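Why does the script work without importing fan_model? torch.load unpickles the saved object, and pickle stores a class only by reference, as its module path (fan_model) and class name (fan_squeeze). While loading, the unpickler imports that module by name itself, so the module only has to be importable from sys.path; because load_compressed_FAN.py lives in the working directory, Python puts that directory on sys.path. Pickle also restores the instance's attributes directly instead of calling __init__, which is why the trained conv and layers come back intact. The stdlib-only sketch below reproduces both effects with a throwaway package; demo_pkg and Net are hypothetical stand-ins for fan_model and fan_squeeze.

```python
import pickle
import sys
import tempfile
from pathlib import Path

# Write a throwaway package, demo_pkg/__init__.py, that defines a class.
# (demo_pkg and Net are hypothetical stand-ins for fan_model and fan_squeeze.)
tmp = Path(tempfile.mkdtemp())
(tmp / "demo_pkg").mkdir()
(tmp / "demo_pkg" / "__init__.py").write_text(
    "INIT_CALLS = []\n"
    "class Net:\n"
    "    def __init__(self):\n"
    "        INIT_CALLS.append(1)\n"
    "        self.weights = [0.0]\n"
)
sys.path.insert(0, str(tmp))  # make demo_pkg importable, like the workdir

import demo_pkg

net = demo_pkg.Net()
net.weights = [1.5, -2.0]      # pretend these are trained weights
blob = pickle.dumps(net)       # stores "demo_pkg" + "Net", not the class code

# Simulate a fresh script that never imported demo_pkg.
del sys.modules["demo_pkg"]
del demo_pkg

restored = pickle.loads(blob)  # the unpickler imports demo_pkg on its own

print(type(restored).__module__, type(restored).__name__)  # demo_pkg Net
print(restored.weights)                                     # [1.5, -2.0]
# The module was freshly re-imported, and __init__ was not called while loading:
print(sys.modules["demo_pkg"].INIT_CALLS)                   # []
```

The same mechanism explains the NOTICE in the script: if fan_model cannot be imported, the unpickler raises an error when it tries to look up fan_squeeze.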

Now check that your directory structure looks like this:

WORKDIR
  |--checkpoints
  |       |__ compressed_model_011000.pth
  |
  |--fan_model
  |       |__ __init__.py
  |
  |--figure
  |       |__ eval_target_image.jpeg
  |
  |__load_compressed_FAN.py
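Before running the script, a quick preflight check can confirm this layout. The sketch below is hypothetical (the function name and the REQUIRED list are not part of the repository); it lists the files the demo reads, including figure/eval_target_image.jpeg, which the script opens as its input image, and reports any that are missing under a given working directory.

```python
from pathlib import Path

# Files the demo needs, relative to the working directory (from the steps above).
REQUIRED = [
    "checkpoints/compressed_model_011000.pth",
    "fan_model/__init__.py",
    "figure/eval_target_image.jpeg",
    "load_compressed_FAN.py",
]


def missing_files(workdir: str = ".") -> list:
    """Return the required paths that are not regular files under workdir."""
    root = Path(workdir)
    return [rel for rel in REQUIRED if not (root / rel).is_file()]


if __name__ == "__main__":
    missing = missing_files(".")
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("Directory layout looks good.")
```

Run it from the working directory; an empty result means all four paths are in place.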

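Run load_compressed_FAN.py from the working directory (the paths in the script are relative to it). It should print the shape of the input tensor and the shape of the output heatmap.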

DeokyunKim commented 3 years ago

Thanks, poortuning!