byungjae89 / SPADE-pytorch

PyTorch implementation of "Sub-Image Anomaly Detection with Deep Pyramid Correspondences"
Apache License 2.0

Advice on GPU memory limit #3

Open · MichalZajac opened 3 years ago

MichalZajac commented 3 years ago

Hi,

Do you have a recommendation on how to modify the code to process larger datasets on a GPU with 11 GB of available memory? Thank you

byungjae89 commented 3 years ago

Hello, @MichalZajac

The SPADE paper itself does not consider GPU memory consumption, since the MVTec dataset is quite small. The main source of memory consumption is that this algorithm keeps all training examples in memory.

I think there are many ways to reduce memory consumption. The simplest is to use a k-means clustering algorithm to select K representative training examples from the N available ones (K < N); see the sketch after the reference below. If your dataset is too large to run k-means, an incremental clustering algorithm would be a possible solution, such as the Agg-Var clustering method introduced in another paper [1].

[1] Cognitively-Inspired Model for Incremental Learning Using a Few Examples. CVPRW'20.
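
A minimal sketch of the k-means idea, assuming the training features have already been flattened into an (N, D) tensor; `select_representatives`, the tensor layout, and the k-means settings are illustrative assumptions, not code from this repo:

```python
import torch
from sklearn.cluster import KMeans

def select_representatives(train_feats: torch.Tensor, k: int) -> torch.Tensor:
    """Pick K representative rows from an (N, D) feature tensor via k-means."""
    feats = train_feats.detach().cpu()
    km = KMeans(n_clusters=k, n_init=10, random_state=0).fit(feats.numpy())
    centers = torch.from_numpy(km.cluster_centers_).to(feats.dtype)
    # Keep the real training example nearest each centroid, so the reduced
    # gallery still consists of genuine features rather than cluster means.
    dists = torch.cdist(centers, feats)  # (K, N) pairwise distances
    idx = dists.argmin(dim=1)
    return feats[idx]

# e.g. shrink a gallery of 2000 feature vectors down to 100
reduced = select_representatives(torch.randn(2000, 256), k=100)
```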

AntixK commented 3 years ago

I had the same issue, but I could resolve it by simply unloading tensors from the GPU once they have been processed. For instance, the input tensor x need not stay on the GPU after the forward pass, so you can delete it; the same goes for pred. The bulk of the GPU memory goes toward the dict train_outputs, so before appending the variable v, simply move it to the CPU. These small optimizations add up: I could run the model with thousands of data points (a sketch of the pattern is below). Hope this helps
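
For concreteness, here is a sketch of that pattern, assuming a hook-based feature extraction loop like this repo's; the backbone, hook wiring, and the dummy loader are approximations, not the exact source:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchvision.models import wide_resnet50_2

# Backbone and hooks similar to this repo's setup (an approximation,
# not a verbatim copy of the source).
model = wide_resnet50_2(pretrained=True).cuda().eval()

outputs = []
def hook(module, inp, out):
    outputs.append(out)

model.layer1[-1].register_forward_hook(hook)
model.layer2[-1].register_forward_hook(hook)
model.layer3[-1].register_forward_hook(hook)
model.avgpool.register_forward_hook(hook)

train_outputs = {'layer1': [], 'layer2': [], 'layer3': [], 'avgpool': []}

# Dummy loader standing in for the real MVTec train loader.
train_dataloader = DataLoader(
    TensorDataset(torch.randn(8, 3, 224, 224), torch.zeros(8)),
    batch_size=4)

with torch.no_grad():
    for x, _ in train_dataloader:
        x = x.cuda()
        pred = model(x)
        del x, pred                       # not needed after the forward pass
        for name, v in zip(train_outputs.keys(), outputs):
            train_outputs[name].append(v.cpu())  # offload features before storing
        outputs.clear()                   # reset the hook buffer for the next batch
```

Storing only CPU copies keeps the GPU footprint at roughly one batch of activations, independent of dataset size.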