VITA-Group / DeblurGANv2

[ICCV 2019] "DeblurGAN-v2: Deblurring (Orders-of-Magnitude) Faster and Better" by Orest Kupyn, Tetiana Martyniuk, Junru Wu, Zhangyang Wang

How can we do transfer learning? #107

Open · rengotj opened this issue 2 years ago

rengotj commented 2 years ago

Hello,

Thank you very much for this really interesting project. I am wondering about transfer learning. Is it possible to specify, in the model configuration, a previously trained model from which the new training should start?

Thank you for your help. Best regards.

LeDat98 commented 3 months ago

First, add a top-level parameter for the pre-trained weights to `config/config.yaml`: `pretrained_model_path: "path/to/pretrained/model.h5"`
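Because this key is optional, it is worth deciding what should happen when it is absent (train from scratch) versus when it names a file that does not exist (probably a typo). A minimal sketch of a stricter lookup, assuming the YAML has already been parsed into a dict (for example with `yaml.safe_load`); `resolve_pretrained_path` is a hypothetical helper, not part of DeblurGANv2:

```python
import os
from typing import Any, Mapping, Optional


def resolve_pretrained_path(config: Mapping[str, Any]) -> Optional[str]:
    """Hypothetical helper: return the checkpoint path from the top-level
    `pretrained_model_path` key, or None if the key is absent or empty
    (meaning: train from scratch).

    Raises FileNotFoundError when a path is configured but missing, so a typo
    in config.yaml cannot silently fall back to training from scratch.
    """
    path = config.get("pretrained_model_path")
    if not path:
        return None
    path = os.path.expanduser(path)  # allow paths such as ~/weights/model.h5
    if not os.path.isfile(path):
        raise FileNotFoundError(f"pretrained_model_path does not exist: {path}")
    return path
```

Failing loudly here is a design choice; if you would rather keep training when the file is missing, log a warning and return `None` instead of raising.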

Then change the `_init_params` method in `train.py` to:

```python
# train.py needs `import os` and `import torch` at the top for this method
def _init_params(self):
    self.criterionG, criterionD = get_loss(self.config['model'])
    self.netG, netD = get_nets(self.config['model'])
    self.netG.cuda()  # Move generator to GPU

    # Wrap in DataParallel only if the generator is not wrapped already;
    # wrapping twice would give parameter names a 'module.module.' prefix
    if torch.cuda.device_count() > 1 and not isinstance(self.netG, torch.nn.DataParallel):
        self.netG = torch.nn.DataParallel(self.netG)
        print(f"Using {torch.cuda.device_count()} GPUs with DataParallel")

    pretrained_path = self.config.get('pretrained_model_path')
    if pretrained_path and os.path.isfile(pretrained_path):
        # Load on CPU; load_state_dict() copies the values onto the model's device
        checkpoint = torch.load(pretrained_path, map_location='cpu')
        state_dict = checkpoint['model']

        # A model trained with DataParallel saves keys with a leading 'module.';
        # make the checkpoint's key prefix match the current generator
        model_is_parallel = isinstance(self.netG, torch.nn.DataParallel)
        ckpt_is_parallel = any(k.startswith('module.') for k in state_dict)
        if ckpt_is_parallel and not model_is_parallel:
            # Strip only the leading prefix, not 'module.' elsewhere in a key
            state_dict = {(k[len('module.'):] if k.startswith('module.') else k): v
                          for k, v in state_dict.items()}
        elif model_is_parallel and not ckpt_is_parallel:
            state_dict = {'module.' + k: v for k, v in state_dict.items()}

        self.netG.load_state_dict(state_dict)
        print(f"Loaded pre-trained model from {pretrained_path}")
    elif pretrained_path:
        print(f"Warning: {pretrained_path} not found, training from scratch")

    self.adv_trainer = self._get_adversarial_trainer(self.config['model']['d_name'], netD, criterionD)
    self.model = get_model(self.config['model'])
    self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters()))
    self.optimizer_D = self._get_optim(self.adv_trainer.get_params())
    self.scheduler_G = self._get_scheduler(self.optimizer_G)
    self.scheduler_D = self._get_scheduler(self.optimizer_D)
```
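For transfer learning to a generator whose layers do not all match the checkpoint, a strict `load_state_dict` fails, and in PyTorch a shape mismatch raises an error even with `strict=False`. A minimal sketch, assuming you want to copy only the tensors whose name and shape match; `align_module_prefix` and `select_transferable` are hypothetical helpers, not part of DeblurGANv2:

```python
from typing import Any, Dict, List, Mapping, Tuple

PREFIX = "module."  # prefix nn.DataParallel adds to every state_dict key


def align_module_prefix(state_dict: Mapping[str, Any], model_is_parallel: bool) -> Dict[str, Any]:
    """Hypothetical helper: add or strip the leading 'module.' so checkpoint
    keys match a model that is (or is not) wrapped in DataParallel."""
    ckpt_is_parallel = any(k.startswith(PREFIX) for k in state_dict)
    if ckpt_is_parallel and not model_is_parallel:
        # Strip only the leading prefix, never occurrences later in the key
        return {(k[len(PREFIX):] if k.startswith(PREFIX) else k): v
                for k, v in state_dict.items()}
    if model_is_parallel and not ckpt_is_parallel:
        return {PREFIX + k: v for k, v in state_dict.items()}
    return dict(state_dict)


def select_transferable(
    model_state: Mapping[str, Any], ckpt_state: Mapping[str, Any]
) -> Tuple[Dict[str, Any], List[str]]:
    """Hypothetical helper: keep checkpoint entries whose name exists in the
    model with an identical shape. Returns (weights to load, skipped names).
    Values only need a `.shape` attribute, so torch tensors work as-is."""
    keep: Dict[str, Any] = {}
    skipped: List[str] = []
    for name, value in ckpt_state.items():
        target = model_state.get(name)
        if target is not None and tuple(target.shape) == tuple(value.shape):
            keep[name] = value
        else:
            skipped.append(name)
    return keep, skipped


# Assumed usage inside _init_params (not executed here):
#   state = torch.load(path, map_location='cpu')['model']
#   state = align_module_prefix(state, isinstance(self.netG, torch.nn.DataParallel))
#   weights, skipped = select_transferable(self.netG.state_dict(), state)
#   self.netG.load_state_dict(weights, strict=False)  # unmatched layers keep their init
```

Printing `skipped` after loading is a cheap way to confirm that only the layers you intended to retrain were left at their fresh initialization.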