Julian-Wyatt / AnoDDPM

CVPR Workshop paper - AnoDDPM: Anomaly Detection with Denoising Diffusion Probabilistic Models using Simplex Noise
https://julianwyatt.co.uk/anoddpm
MIT License
154 stars, 27 forks

Batch Size Increase #13

Closed: WilliamJudge94 closed this issue 1 year ago

WilliamJudge94 commented 1 year ago

Issue: When I increase the batch size in the args28.json file, training errors out.

Expected: I would like to increase the batch size to speed up training. Current memory usage is 4 GB out of 16 GB of VRAM, so I should be able to increase the batch size to 3 or 4.
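The "3 or 4" estimate follows from a simple linear model. The helper below is a back-of-the-envelope sketch, not part of AnoDDPM: it assumes memory grows linearly with batch size, and `fixed_gb` (memory that does not scale with the batch) and `headroom_gb` (memory kept free) are hypothetical parameters you would have to measure yourself.

```python
import math


def estimate_max_batch_size(mem_at_batch_1_gb, total_gb, fixed_gb=0.0, headroom_gb=0.0):
    """Rough upper bound on batch size under a linear memory model (assumption).

    mem_at_batch_1_gb: measured usage at batch size 1.
    fixed_gb: hypothetical share of that usage that does not grow with the batch.
    headroom_gb: hypothetical amount of VRAM to leave unused.
    """
    per_sample_gb = mem_at_batch_1_gb - fixed_gb
    if per_sample_gb <= 0:
        raise ValueError("fixed_gb must be smaller than the batch-1 usage")
    usable_gb = total_gb - headroom_gb - fixed_gb
    # Round down: a partial sample does not fit.
    return max(0, math.floor(usable_gb / per_sample_gb))


print(estimate_max_batch_size(4.0, 16.0))                   # 4
print(estimate_max_batch_size(4.0, 16.0, headroom_gb=1.0))  # 3
```

With no fixed cost and no headroom, 16 / 4 gives exactly 4, which fills the card completely; reserving any headroom drops the estimate to 3, which makes 3 the safer first try.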

Reproduce: Change the dataset in args28.json to "cifar" and set Batch_Size to 4.
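The reproduction step is a two-key config change. A minimal sketch of making it programmatically, assuming args28.json is a flat JSON object; the `Batch_Size` key comes from the issue, while the `"dataset"` key name and the `update_args` helper are hypothetical.

```python
import json
import tempfile
from pathlib import Path


def update_args(path, dataset="cifar", batch_size=4):
    """Hypothetical helper: set dataset and batch size in an args JSON file.

    Assumes a flat JSON object; the "dataset" key name is a guess.
    """
    path = Path(path)
    args = json.loads(path.read_text())
    args["dataset"] = dataset
    args["Batch_Size"] = batch_size
    path.write_text(json.dumps(args, indent=2))
    return args


# Example on a throwaway file rather than the real config.
with tempfile.TemporaryDirectory() as tmp:
    demo = Path(tmp) / "args28.json"
    demo.write_text(json.dumps({"Batch_Size": 1}))
    updated = update_args(demo)

print(updated)  # {'Batch_Size': 4, 'dataset': 'cifar'}
```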

Error message: (screenshot not included)

Julian-Wyatt commented 1 year ago

I may struggle to debug this myself, as I no longer have access to a GPU to train on, but I'll take a short look and get back to you. There's a chance I wrote a section that simply doesn't work with a batch size larger than 1, since I never needed to increase it (which is an awful oversight on my part).

Julian-Wyatt commented 1 year ago
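simplex_tensor = torch.from_numpy(
    # The inner method expects a single time value, so pass only t[0:1]
    # rather than the whole batch of timesteps
    Simplex_instance.rand_3d_fixed_T_octaves(
        x.shape[-2:], t.detach().cpu().numpy()[0:1], octave,
        persistence, frequency
    )
).to(x.device)

# Add a leading dimension of size 1 (equivalently simplex_tensor.unsqueeze(0))
simplex_tensor = simplex_tensor.view(1, *simplex_tensor.shape)

# noise[:, i, ...] spans the whole batch for channel i; the size-1 leading
# dimension broadcasts across it, so no .repeat(...) call is needed
noise[:, i, ...] = simplex_tensor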

I can't promise it will fully work, as I can't train it on my machine, but the error no longer appears. Try updating the generate_simplex_noise method with the code above.
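Since the fix could not be trained on a GPU, its shape logic can at least be checked on the CPU. The sketch below mirrors the patched lines with NumPy arrays instead of torch tensors (assigning a (1, H, W) array into a (B, H, W) slice broadcasts the same way in both). `fake_rand_3d_fixed_T_octaves` and `generate_noise_channelwise` are hypothetical: the stand-in returns one dummy (H, W) map for a single time value, and the real method's return shape is an assumption.

```python
import numpy as np


def fake_rand_3d_fixed_T_octaves(shape, T):
    """Hypothetical stand-in for Simplex_instance.rand_3d_fixed_T_octaves.

    Assumption: given one time value, it returns a single (H, W) map.
    Values are dummies (T plus the row index), not simplex noise.
    """
    if len(T) != 1:
        raise ValueError("stand-in expects exactly one time value")
    h, w = shape
    return np.full((h, w), float(T[0])) + np.arange(h)[:, None]


def generate_noise_channelwise(x, t, in_channels):
    """Mirror of the patch: one map per channel, shared across the batch."""
    noise = np.empty(x.shape)
    for i in range(in_channels):
        # Only the first timestep is passed to the inner method.
        simplex = fake_rand_3d_fixed_T_octaves(x.shape[-2:], t[0:1])
        # Leading size-1 axis, like simplex_tensor.view(1, *shape).
        simplex = simplex[np.newaxis, ...]  # (1, H, W)
        # noise[:, i, ...] is (B, H, W); the size-1 axis broadcasts.
        noise[:, i, ...] = simplex
    return noise


x = np.zeros((4, 1, 8, 8))  # batch of 4, one channel, 8x8 images
t = np.array([10, 20, 30, 40])
noise = generate_noise_channelwise(x, t, in_channels=1)
print(noise.shape)  # (4, 1, 8, 8)
```

Every sample in the batch receives the same map, built from the first timestep only; that is the trade-off of passing `t[0:1]`.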

Notably, the time value isn't dependent on anything, but the inner method needs a single time value rather than one per batch element. The repeat call has also been removed.
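One consequence of passing only `t[0:1]` is that every sample shares the noise generated for the first sample's timestep. If per-sample timesteps matter, one option (not proposed in the thread) is to call the inner method once per batch element. Again a NumPy sketch, with the same hypothetical stand-in for `rand_3d_fixed_T_octaves` and a hypothetical `generate_noise_per_sample` helper:

```python
import numpy as np


def fake_rand_3d_fixed_T_octaves(shape, T):
    """Hypothetical stand-in: one dummy (H, W) map for a single time value."""
    if len(T) != 1:
        raise ValueError("stand-in expects exactly one time value")
    h, w = shape
    return np.full((h, w), float(T[0])) + np.arange(h)[:, None]


def generate_noise_per_sample(x, t, in_channels):
    """Alternative sketch: each sample gets noise for its own timestep."""
    batch = x.shape[0]
    noise = np.empty(x.shape)
    for i in range(in_channels):
        for b in range(batch):
            # Still one time value per call, as the inner method expects.
            noise[b, i, ...] = fake_rand_3d_fixed_T_octaves(x.shape[-2:], t[b:b + 1])
    return noise


x = np.zeros((4, 1, 8, 8))
t = np.array([10, 20, 30, 40])
per_sample = generate_noise_per_sample(x, t, in_channels=1)
```

This makes B calls per channel instead of one, so noise generation gets slower as the batch grows.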