AutoResearch / EEG-GAN


Padding added if timeseries not divisible by patch_size even if AE time_out is divisible #99

Closed · chadcwilliams closed this issue 2 months ago

chadcwilliams commented 3 months ago

If the first if statement is False, execution falls through to the elif, which may still evaluate to True even though it should be skipped when an autoencoder (AE) is loaded, because its condition does not account for the AE.

We need to change `elif opt['gan_type'] == 'tts' and opt['sequence_length'] % opt['patch_size'] != 0:` to `elif opt['gan_type'] == 'tts' and not ae_dict and opt['sequence_length'] % opt['patch_size'] != 0:` so that the branch is skipped whenever an AE is loaded.

```python
if opt['gan_type'] == 'tts' and ae_dict and (ae_dict['configuration']['target'] == 'full' or ae_dict['configuration']['target'] == 'time') and ae_dict['configuration']['time_out'] % opt['patch_size'] != 0:
    warnings.warn(
        f"Sequence length ({ae_dict['configuration']['timeseries_out']}) must be a multiple of patch size ({default_args['patch_size']}).\n"
        f"The sequence length is padded with zeros to fit the condition.")
    # find the smallest padding that makes the AE output length divisible by patch_size
    padding = 0
    while (ae_dict['configuration']['timeseries_out'] + padding) % default_args['patch_size'] != 0:
        padding += 1

    padding = torch.zeros((dataset.shape[0], padding, dataset.shape[-1]))
    dataset = torch.cat((dataset, padding), dim=1)
    opt['sequence_length'] = dataset.shape[1] - dataloader.labels.shape[1]
# proposed fix applied: `not ae_dict` skips this branch when an AE is loaded
elif opt['gan_type'] == 'tts' and not ae_dict and opt['sequence_length'] % opt['patch_size'] != 0:
    warnings.warn(
        f"Sequence length ({opt['sequence_length']}) must be a multiple of patch size ({default_args['patch_size']}).\n"
        f"The sequence length is padded with zeros to fit the condition.")
    padding = 0
    while (opt['sequence_length'] + padding) % default_args['patch_size'] != 0:
        padding += 1
    padding = torch.zeros((dataset.shape[0], padding, dataset.shape[-1]))
    dataset = torch.cat((dataset, padding), dim=1)
    opt['sequence_length'] = dataset.shape[1] - dataloader.labels.shape[1]
else:
    padding = torch.zeros((dataset.shape[0], 0, dataset.shape[-1]))
```
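
As an aside, the smallest zero-padding that makes a length divisible by the patch size can be computed in one step with modular arithmetic instead of the while loop. A minimal sketch (the helper name `pad_amount` is hypothetical, not part of the codebase):

```python
def pad_amount(length: int, patch_size: int) -> int:
    # smallest non-negative k such that (length + k) % patch_size == 0
    return (-length) % patch_size

assert pad_amount(100, 8) == 4  # 104 is the next multiple of 8
assert pad_amount(96, 8) == 0   # already divisible, no padding needed
```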
whyhardt commented 2 months ago

I removed the automatic padding in my dev-branch. It added too much complexity to the code for little benefit. I added two error messages instead, one for the GAN and one for the AE-GAN.

See issue https://github.com/AutoResearch/EEG-GAN/issues/98#issue
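
For reference, a minimal sketch of what such explicit checks could look like in place of the automatic padding. The exact messages, condition layout, and variable names in the dev-branch may differ; this is an assumption, not the actual implementation:

```python
# hypothetical sketch of the divisibility checks replacing automatic padding
if opt['gan_type'] == 'tts':
    if ae_dict:
        # AE-GAN: the autoencoder's output length must match the patch size
        if ae_dict['configuration']['time_out'] % opt['patch_size'] != 0:
            raise ValueError(
                f"AE output length ({ae_dict['configuration']['time_out']}) "
                f"must be a multiple of patch size ({opt['patch_size']}).")
    elif opt['sequence_length'] % opt['patch_size'] != 0:
        # plain GAN: the raw sequence length must match the patch size
        raise ValueError(
            f"Sequence length ({opt['sequence_length']}) "
            f"must be a multiple of patch size ({opt['patch_size']}).")
```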