AutoResearch / EEG-GAN

DDP is broken #84

Closed chadcwilliams closed 3 months ago

chadcwilliams commented 3 months ago

DDP training is broken now for both the AE and the GAN, maybe due to changes you made, @whyhardt?

It skips training entirely and just reports, e.g., "GAN training done". It doesn't save anything.

I might have narrowed it down to _ddp_training() (165) in ddp_training but haven't gotten further yet. I can keep looking tonight/tomorrow.

chadcwilliams commented 3 months ago

In _ddp_training it looks like it forces normalization of the data?

    dataloader = Dataloader(opt['data'],
                            kw_time=opt['kw_timestep'],
                            kw_conditions=opt['conditions'],
                            norm_data=True,
                            kw_channel=opt['kw_channel'])

Is this on purpose? I don't think it is, so I committed the change. I also added the missing parameters (e.g., std_data).

Edit: Ah, I see these are all hard-coded anyway, so it would all match up in the end.

chadcwilliams commented 3 months ago

Okay I got it.

@whyhardt: I removed the loading of the optimizers' state dicts and don't believe this will cause any problems. If you do see a way that it could cause a problem, let's discuss and find out how to re-introduce it.

For the GAN, there were two problems:

  1. It was using a nonexistent self.learning_rate attribute.
  2. Loading the state dict breaks it.

In ddp_training/set_ddp_framework:

Original code:

    self.generator_optimizer = torch.optim.Adam(self.generator.parameters(),
                                                lr=self.learning_rate, betas=(self.b1, self.b2))
    self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(),
                                                    lr=self.learning_rate, betas=(self.b1, self.b2))

    self.generator_optimizer.load_state_dict(g_opt_state)
    self.discriminator_optimizer.load_state_dict(d_opt_state)

Modified code:

    self.generator_optimizer = torch.optim.Adam(self.generator.parameters(),
                                                lr=self.g_lr, betas=(self.b1, self.b2))
    self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(),
                                                    lr=self.d_lr, betas=(self.b1, self.b2))

For the AE, there were also two problems:

  1. It was using a nonexistent opt['kw_conditions'] key in the dataloader.
  2. Loading the state dict breaks it.

In ddp_training/_ddp_training:

Original code:

    dataloader = Dataloader(opt['data'],
                            kw_time=opt['kw_time'],
                            kw_conditions=opt['kw_conditions'],
                            norm_data=opt['norm_data'],
                            std_data=opt['std_data'],
                            diff_data=opt['diff_data'],
                            kw_channel=opt['kw_channel'])

Modified code:

    if isinstance(trainer_ddp, GANDDPTrainer):
        dataloader = Dataloader(opt['data'],
                                kw_time=opt['kw_time'],
                                kw_conditions=opt['kw_conditions'],
                                norm_data=opt['norm_data'],
                                std_data=opt['std_data'],
                                diff_data=opt['diff_data'],
                                kw_channel=opt['kw_channel'])
    elif isinstance(trainer_ddp, AEDDPTrainer):
        dataloader = Dataloader(opt['data'],
                                kw_time=opt['kw_time'],
                                norm_data=opt['norm_data'],
                                std_data=opt['std_data'],
                                diff_data=opt['diff_data'],
                                kw_channel=opt['kw_channel'])
    else:
        raise ValueError(f"Trainer type {type(trainer_ddp)} not supported.")
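
The two branches above differ only in whether kw_conditions is passed, so the shared arguments could be collected once. This is just a sketch of that idea, not what was committed: build_dataloader_kwargs is a hypothetical helper, and the option values in the example are made up for illustration.

```python
def build_dataloader_kwargs(opt, include_conditions):
    """Collect the Dataloader keyword arguments from the options dict.

    Hypothetical helper: 'kw_conditions' is only read when requested,
    so an AE options dict without that key does not raise a KeyError.
    """
    kwargs = {
        "kw_time": opt["kw_time"],
        "norm_data": opt["norm_data"],
        "std_data": opt["std_data"],
        "diff_data": opt["diff_data"],
        "kw_channel": opt["kw_channel"],
    }
    if include_conditions:
        kwargs["kw_conditions"] = opt["kw_conditions"]
    return kwargs


# Illustrative option dicts (values are assumptions, not from the repo).
ae_opt = {
    "data": "data.csv",
    "kw_time": "Time",
    "norm_data": True,
    "std_data": False,
    "diff_data": False,
    "kw_channel": "Electrode",
}
gan_opt = dict(ae_opt, kw_conditions=["Condition"])

ae_kwargs = build_dataloader_kwargs(ae_opt, include_conditions=False)
gan_kwargs = build_dataloader_kwargs(gan_opt, include_conditions=True)
```

In _ddp_training this would presumably be called as Dataloader(opt['data'], **build_dataloader_kwargs(opt, isinstance(trainer_ddp, GANDDPTrainer))), with the ValueError for unsupported trainer types still checked beforehand.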

In ddp_training/set_ddp_framework:

Original code:

    self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
    self.optimizer.load_state_dict(opt_state)

Modified code:

    self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
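
If the optimizer state needs to be re-introduced later, one cautious option is to make the restore optional instead of unconditional. The thread does not say why load_state_dict was failing, so this is only a sketch under assumptions: try_restore_optimizer_state is a hypothetical helper, and it relies on PyTorch raising ValueError when the saved parameter groups do not match the optimizer's.

```python
import warnings


def try_restore_optimizer_state(optimizer, saved_state):
    """Load saved_state into optimizer if possible; return True on success.

    Hypothetical helper. 'optimizer' is expected to expose
    load_state_dict(), as torch.optim optimizers do.
    """
    if saved_state is None:
        # Nothing was saved (e.g. a fresh run), so keep the new optimizer.
        return False
    try:
        optimizer.load_state_dict(saved_state)
    except ValueError as err:
        # Assumption: a parameter-group mismatch surfaces as ValueError.
        warnings.warn(f"Optimizer state not restored: {err}")
        return False
    return True
```

In set_ddp_framework this could replace the removed lines, e.g. try_restore_optimizer_state(self.generator_optimizer, g_opt_state), so that training continues with a fresh optimizer when the saved state cannot be applied.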