AutoResearch / EEG-GAN

AE-Training: can't load checkpoint #83

Closed: whyhardt closed this issue 3 months ago

whyhardt commented 3 months ago

Cannot load checkpoint in AE-Training.

I suspect it comes from how the code distinguishes between model, model_1, and model_2.

It works for AE-GAN-Training, though.

Error:

RuntimeError: Error(s) in loading state_dict for TransformerDoubleAutoencoder:
        Unexpected key(s) in state_dict: "model_1.encoder.0.weight", "model_1.encoder.0.bias", "model_1.encoder.3.weight", "model_1.encoder.3.bias", "model_1.encoder.6.weight", "model_1.encoder.6.bias", "model_1.encoder.9.weight", "model_1.encoder.9.bias", "model_1.decoder.0.weight", "model_1.decoder.0.bias", "model_1.decoder.3.weight", "model_1.decoder.3.bias", "model_1.decoder.6.weight", "model_1.decoder.6.bias", "model_1.decoder.9.weight", "model_1.decoder.9.bias", ...

chadcwilliams commented 3 months ago

Ah, yeah, this is something I meant to come back to. Because of the way I refactored the double AE, continued training from a checkpoint currently fails. Essentially, the double AE trains two nested AEs: first it trains on the timeseries dimension, and only once that training is fully complete (i.e., all epochs have finished) does it train on the channels dimension, with the timeseries dimension encoded by the first AE. Structurally, the timeseries AE is added as a nested submodule within the channels AE, under the name model_1.
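To illustrate why the checkpoint contains "model_1.*" keys, here is a minimal sketch assuming plain PyTorch modules. TimeseriesAE and ChannelsAE are hypothetical stand-ins, not the actual EEG-GAN classes (TransformerDoubleAutoencoder is more complex). Registering one module as an attribute of another prefixes its parameters with the attribute name in state_dict(), so a model built without that attribute rejects them under the default strict loading.

```python
import torch.nn as nn


# Hypothetical stand-ins for the two autoencoders; the real EEG-GAN classes differ.
class TimeseriesAE(nn.Module):
    def __init__(self, seq_len=16, latent=4):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(seq_len, latent))
        self.decoder = nn.Sequential(nn.Linear(latent, seq_len))


class ChannelsAE(nn.Module):
    def __init__(self, n_channels=8, latent=2, inner=None):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(n_channels, latent))
        self.decoder = nn.Sequential(nn.Linear(latent, n_channels))
        if inner is not None:
            # Storing the timeseries AE as an attribute registers it as a submodule,
            # so its parameters appear in state_dict() with the "model_1." prefix.
            self.model_1 = inner


trained = ChannelsAE(inner=TimeseriesAE())
keys = list(trained.state_dict().keys())

# A model rebuilt WITHOUT the nested AE cannot accept the "model_1.*" keys.
fresh = ChannelsAE()
error_message = ""
try:
    fresh.load_state_dict(trained.state_dict())
except RuntimeError as err:
    error_message = str(err)  # "... Unexpected key(s) in state_dict: "model_1.encoder.0.weight", ..."

# Rebuilding the same nesting before loading makes every key match.
rebuilt = ChannelsAE(inner=TimeseriesAE())
result = rebuilt.load_state_dict(trained.state_dict())
```

The takeaway is that a double-AE checkpoint can only be loaded into a model that has already been constructed with the nested model_1 submodule.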

I refactored it this way because, when we were trying to train both dimensions at once, the model ran the timeseries training and then the channels training on every batch iteration. As a result, both dimensions were constantly changing, and I thought this was causing the training instability.
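The sequential scheme described above, as opposed to alternating both updates on every batch, can be sketched roughly as follows. This is an illustration under stated assumptions, not the project's training loop: the linear autoencoders, the random placeholder data, the train helper, and the explicit freezing of stage 1 are all hypothetical choices made for the example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
SEQ_LEN, N_CHANNELS, LATENT_T, LATENT_C = 16, 8, 4, 2

# Hypothetical linear autoencoders: index 0 is the encoder, index 1 the decoder.
timeseries_ae = nn.Sequential(nn.Linear(SEQ_LEN, LATENT_T), nn.Linear(LATENT_T, SEQ_LEN))
channels_ae = nn.Sequential(nn.Linear(N_CHANNELS, LATENT_C), nn.Linear(LATENT_C, N_CHANNELS))

data = torch.randn(32, SEQ_LEN, N_CHANNELS)  # (batch, time, channels), placeholder data
loss_fn = nn.MSELoss()


def train(model, inputs, epochs=50, lr=1e-2):
    """Train one autoencoder to reconstruct its inputs for ALL epochs."""
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(epochs):
        opt.zero_grad()
        loss = loss_fn(model(inputs), inputs)
        loss.backward()
        opt.step()
    return loss.item()


# Stage 1: train the timeseries AE to completion along the time dimension.
per_channel = data.transpose(1, 2)  # (batch, channels, time)
stage1_loss = train(timeseries_ae, per_channel)

# Freeze stage 1 (an assumption here) so it stays fixed while stage 2 trains.
for p in timeseries_ae.parameters():
    p.requires_grad_(False)
frozen_before = [p.detach().clone() for p in timeseries_ae.parameters()]

# Stage 2: train the channels AE on the time-encoded data.
with torch.no_grad():
    encoded_time = timeseries_ae[0](per_channel).transpose(1, 2)  # (batch, LATENT_T, channels)
stage2_loss = train(channels_ae, encoded_time)
```

Because stage 2 only ever sees the output of a fixed stage-1 encoder, its targets do not shift underneath it, which is the instability the per-batch alternation was suspected of causing.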

So, the double AE as it stands relies on the sequential timeseries AE -> channels AE order. We could implement loading a checkpoint and then continuing training in that same timeseries AE -> channels AE order, or we could simply disallow continued training on the double AE (at least for now).
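Both options could start from the same check on the checkpoint's keys. The sketch below is hypothetical: split_nested_state_dict and ensure_resumable are illustrative helpers, not functions from EEG-GAN, and the only detail taken from the thread is that the nested timeseries AE lives under the "model_1." prefix. Splitting the keys would let the timeseries stage be restored on its own (option 1); the guard implements option 2 by refusing to resume.

```python
from typing import Any, Dict, Mapping, Tuple

# Attribute name of the nested timeseries AE, as described in the thread.
NESTED_PREFIX = "model_1."


def split_nested_state_dict(
    state_dict: Mapping[str, Any], prefix: str = NESTED_PREFIX
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Return (channels-AE entries, timeseries-AE entries with the prefix stripped)."""
    outer: Dict[str, Any] = {}
    inner: Dict[str, Any] = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            inner[key[len(prefix):]] = value
        else:
            outer[key] = value
    return outer, inner


def ensure_resumable(state_dict: Mapping[str, Any], allow_double_ae: bool = False) -> None:
    """Option 2: refuse continued training when the checkpoint is a double AE."""
    if not allow_double_ae and any(k.startswith(NESTED_PREFIX) for k in state_dict):
        raise ValueError("Continued training is not supported for double-AE checkpoints yet.")


# Usage with placeholder values (a real state_dict would map keys to tensors).
checkpoint = {"encoder.0.weight": "W_c", "model_1.encoder.0.weight": "W_t"}
outer, inner = split_nested_state_dict(checkpoint)
```

With option 1, inner could be loaded into a freshly built timeseries AE to resume stage 1 before the channels stage continues; with option 2, ensure_resumable would be called before any resumed training starts.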

Another issue with this approach that needs attention is that the double AE actually runs training twice, which means it saves two separate models: one being the timeseries AE and the other being the channels AE with the nested timeseries AE. So, fixing the save_name parameter could cause both models to be saved under the same name. This would effectively solve the problem, since the channels AE model would overwrite the timeseries AE model, but we might want to skip saving the timeseries AE altogether in this special case.
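Skipping the intermediate save could look something like the following minimal sketch. The save_stage helper, the is_nested_stage flag, and the "ae.pt" file name are hypothetical; only save_name comes from the thread. It uses pickle so the example has no dependencies, where a PyTorch project would typically call torch.save(state, path) instead.

```python
import os
import pickle
import tempfile


def save_stage(state, save_name, directory, is_nested_stage, skip_nested=True):
    """Write one training stage's checkpoint; return its path, or None if skipped."""
    if is_nested_stage and skip_nested:
        # The timeseries AE is already stored inside the channels AE (as model_1),
        # so a separate file under the same save_name would just be overwritten.
        return None
    path = os.path.join(directory, save_name)
    with open(path, "wb") as fh:
        pickle.dump(state, fh)  # stand-in for torch.save(state, path)
    return path


with tempfile.TemporaryDirectory() as tmp:
    # Both stages of the double AE use the same save_name.
    first = save_stage({"stage": "timeseries"}, "ae.pt", tmp, is_nested_stage=True)
    second = save_stage({"stage": "channels"}, "ae.pt", tmp, is_nested_stage=False)
    with open(second, "rb") as fh:
        saved = pickle.load(fh)
    files = os.listdir(tmp)
```

Only the final channels AE (which contains the timeseries AE) ends up on disk, so there is a single checkpoint under save_name and no silent overwrite.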

chadcwilliams commented 3 months ago

Fixed on the dev branch.