mohwald / gandtr

PyTorch implementation of the Dark Side Augmentation
MIT License

Load HEDNGAN generator checkpoint into original CUT #4

Closed BinuxLiu closed 5 months ago

BinuxLiu commented 5 months ago

Hello @mohwald! I successfully tested your trained CycleGAN, but I could not load the HEDNGAN checkpoints correctly in the open-source CUT project. I used the default parameters from CUT's base_options.py.

[screenshot attached]

mohwald commented 5 months ago

Hi @BinuxLiu,

The HEDNGAN generator was trained with small differences compared to the CycleGAN generator:

  1. We use batch normalization layers (instead of instance normalization). This changes the overall architecture, and therefore the state_dict keys no longer match.
  2. We set no_antialias=True and no_antialias_up=True. With these options, the generator is built with learnable transposed convolutions in the upsampling part.

You should achieve what you need with this snippet:

# Download the HEDNGAN generator checkpoint and clone the CUT repository (shell):
wget http://ptak.felk.cvut.cz/personal/jenicto2/download/iccv23_gan/hedngan_generator_X.pth
git clone https://github.com/taesungp/contrastive-unpaired-translation cut
# In Python, build the CUT ResnetGenerator with the HEDNGAN hyperparameters:
import torch
import torch.nn as nn
import sys

sys.path.append("./cut")

from models.networks import ResnetGenerator

generator = ResnetGenerator(
    input_nc=3,
    output_nc=3,
    ngf=64,
    norm_layer=nn.BatchNorm2d,
    use_dropout=False,
    n_blocks=9,
    padding_type='reflect',
    no_antialias=True,
    no_antialias_up=True,
    opt=None
)

# The generator weights are stored under the "model_state" key of the checkpoint.
ckpt = torch.load("hedngan_generator_X.pth", map_location=lambda storage, loc: storage)
generator.load_state_dict(ckpt["model_state"])
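As a quick sanity check, here is a minimal inference sketch using the generator loaded above. The [-1, 1] input normalization and the 256x256 resolution are assumptions following the usual CycleGAN/CUT convention, not values taken from this thread:

# Minimal inference sketch (assumed convention: RGB input normalized to [-1, 1]).
generator.eval()
with torch.no_grad():
    # A random 256x256 RGB tensor stands in for a real input photo.
    day_image = torch.rand(1, 3, 256, 256) * 2 - 1
    night_image = generator(day_image)    # ResnetGenerator ends with Tanh, so output is in [-1, 1]
    night_image = (night_image + 1) / 2   # rescale to [0, 1] for saving or viewing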

You can also use our repository to load the checkpoints; this way, all hyperparameter differences are already handled for you. Have a look at stages.

BinuxLiu commented 5 months ago

Thank you for your help! I have successfully tested it.

Given that our research interests align, you might be interested in the following experimental results, which were tested on the NightStreets dataset from my previous work:

| Method   | Dataset       | FID   | L2     | PSNR | SSIM  |
|----------|---------------|-------|--------|------|-------|
| CycleGAN | Retrieval-SFM | 89.52 | 222.91 | 6.06 | -0.19 |
| HEDNGAN  | Retrieval-SFM | 90.88 | 212.30 | 6.44 | -0.19 |
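For reference, a rough sketch of how per-image L2, PSNR, and SSIM between a translated image and its reference could be computed with NumPy and scikit-image; this is only an illustrative assumption, not necessarily the protocol behind the numbers above:

import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

def image_similarity(translated, reference):
    # Inputs are HxWx3 uint8 arrays of the same shape.
    l2 = np.linalg.norm(translated.astype(np.float64) - reference.astype(np.float64))
    psnr = peak_signal_noise_ratio(reference, translated, data_range=255)
    ssim = structural_similarity(reference, translated, channel_axis=-1, data_range=255)
    # FID is computed over image distributions, not per image (e.g. with pytorch-fid).
    return l2, psnr, ssim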
Additionally, I have also replicated the fair comparison experiment we discussed last time, using the model trained on Retrieval-SFM that you mentioned:

| Method    | Tokyo 24/7 Night | SVOX Night | MSLS Night |
|-----------|------------------|------------|------------|
| VGG16-GeM | 31.4             | 2.4        | 1.8        |
| VGG16-DSA | 42.9             | 8.3        | 3.6        |