artyom-beilis / pytorch_dlprim

DLPrimitives/OpenCL out of tree backend for pytorch
http://blog.dlprimitives.org/
MIT License

Model 5MB and batch of data ~50MB takes up 1 GB gpu memory #63

Status: Open. sukamenev opened this issue 4 months ago.

sukamenev commented 4 months ago

Tested on AMD Fury (Fiji chipset).

artyom-beilis commented 4 months ago

What model? Did you run forward/backward iterations?

sukamenev commented 4 months ago

It's my custom model: 512 neurons in the input layer, 15 hidden layers of 256 neurons each, and 2 neurons in the output layer. The state dictionary is ~5 MB.

I'm using pytorch_dlprim only to train my model. The batch size is 25000, and every sample has 512 float32 elements: 25000 × 512 × 4 bytes = ~50 MB.
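As a rough check of the sizes quoted above, here is a minimal sketch. It assumes the 512 inputs feed 16 fully connected layers, all with biases (512→256, fourteen 256→256, 256→2), and float32 storage throughout; the helper name `mlp_param_count` is hypothetical and not part of pytorch_dlprim.

```python
FLOAT32_BYTES = 4


def mlp_param_count(sizes):
    """Weights plus biases of a stack of fully connected layers."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes, sizes[1:]))


# Assumed layout: 512 inputs, 15 hidden layers of 256 neurons, 2 outputs.
sizes = [512] + [256] * 15 + [2]
params = mlp_param_count(sizes)
param_bytes = params * FLOAT32_BYTES

# One input batch: 25000 samples of 512 float32 values each.
batch_bytes = 25000 * 512 * FLOAT32_BYTES

print(params, param_bytes, batch_bytes)  # → 1052930 4211720 51200000
```

That is about 1.05 M parameters, or ~4.2 MB of raw float32 weights (close to the ~5 MB reported), plus ~51 MB for one input batch. Neither comes close to 1 GB on its own.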

Could such a large memory consumption be due to the size of the batch?

sukamenev commented 4 months ago

On the CPU, this batch size gives the fastest training on my data.

I'm also interested in another question: can pytorch_dlprim be used to train several neural networks in parallel?

I'm currently running 2 scripts, each training one network and each taking up 1 GB of GPU memory.

I would like to run more training in parallel, since the card has about 3500 stream processors.

When training one network, GPU utilization is ~80% (though the GPU is sometimes idle, perhaps while data is loading?) and an epoch takes 30 s.

When training 2 networks at once, utilization is ~95-98% and an epoch takes 45 s.
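Comparing total throughput rather than per-epoch time makes the benefit of the second process clearer. This small sketch uses only the two timings reported above; `aggregate_speedup` is a hypothetical helper.

```python
def aggregate_speedup(single_epoch_s, n_jobs, parallel_epoch_s):
    """Epochs completed per second by n_jobs concurrent runs, relative to one run alone."""
    single_rate = 1.0 / single_epoch_s        # epochs/s with one training script
    parallel_rate = n_jobs / parallel_epoch_s  # epochs/s summed over all scripts
    return parallel_rate / single_rate


print(round(aggregate_speedup(30, 2, 45), 2))  # → 1.33
```

Two concurrent runs complete about 1.33 times as many epochs per unit time (an effective 22.5 s per epoch per network), presumably because the second process fills gaps where the GPU would otherwise sit idle.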

artyom-beilis commented 4 months ago

I don't see a problem with running several processes on the same GPU. For faster data loading, I'd suggest preparing the data for the next batch in parallel. PyTorch works asynchronously.

For example:

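```python
r = net(batch)                  # forward pass on the current input batch
l = calc_loss(r, ground_truth)
l.backward()
# ^ these calls return quickly; the GPU work runs asynchronously
loss_value = l.item()
# ^ blocks until the GPU has finished the queued work
```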

So if you prepare the next batch before "syncing" the results, you can improve GPU utilization significantly.
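One way to prepare the next batch in parallel is a small background thread that builds upcoming batches while the main loop queues forward/backward work. This is a minimal standard-library sketch, not part of pytorch_dlprim; `prefetch` and the `make_batch` callback are hypothetical names. A thread is enough when batch preparation is mostly I/O or library code that releases the GIL; heavy pure-Python preprocessing would need worker processes instead.

```python
import queue
import threading
from typing import Callable, Iterator, TypeVar

T = TypeVar("T")


def prefetch(make_batch: Callable[[int], T], num_batches: int,
             depth: int = 2) -> Iterator[T]:
    """Yield make_batch(0), ..., make_batch(num_batches - 1) in order,
    building up to `depth` batches ahead in a background thread."""
    q: "queue.Queue[tuple[bool, object]]" = queue.Queue(maxsize=depth)

    def worker() -> None:
        # (False, batch) carries data; (True, exc_or_None) signals the end.
        try:
            for i in range(num_batches):
                q.put((False, make_batch(i)))
        except Exception as exc:  # forward errors to the consumer
            q.put((True, exc))
            return
        q.put((True, None))

    # Daemon thread: if the consumer stops early, it won't keep the process alive.
    thread = threading.Thread(target=worker, daemon=True)
    thread.start()
    while True:
        finished, payload = q.get()
        if finished:
            thread.join()
            if payload is not None:
                raise payload
            return
        yield payload
```

In a training loop you would iterate `for x, y in prefetch(make_batch, num_batches):`, move `x` and `y` to the device, and call `.item()` only after the step has been queued. PyTorch's own `torch.utils.data.DataLoader` with `num_workers > 0` offers similar background loading using worker processes.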

artyom-beilis commented 4 months ago

> I'm using pytorch_dlprim only to train my model. The batch size is 25000, and every sample has 512 float32 elements: 25000 × 512 × 4 bytes = ~50 MB.

But what about the intermediate values? To compute the forward/backward pass, you need to keep all the intermediate layer outputs. That is what consumes most of the GPU memory, not the model itself.
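A back-of-the-envelope estimate shows the scale involved. This rough sketch assumes each hidden layer keeps either one or two float32 tensors of shape 25000×256 alive for the backward pass (depending on which activation function is used, which the thread does not state). It ignores gradient buffers and allocator caching. `activation_bytes` is a hypothetical helper.

```python
FLOAT32_BYTES = 4


def activation_bytes(batch, input_dim, hidden, n_hidden, out_dim,
                     tensors_per_hidden, bytes_per_elem=FLOAT32_BYTES):
    """Rough size of forward-pass tensors kept alive for backward (assumption-laden)."""
    elems = batch * input_dim                                # the input batch itself
    elems += n_hidden * tensors_per_hidden * batch * hidden  # per-hidden-layer outputs
    elems += batch * out_dim                                 # final 2-wide output
    return elems * bytes_per_elem


for k in (1, 2):
    mb = activation_bytes(25000, 512, 256, 15, 2, tensors_per_hidden=k) / 1e6
    print(f"{k} tensor(s) per hidden layer: ~{mb:.0f} MB")
# → 1 tensor(s) per hidden layer: ~435 MB
# → 2 tensor(s) per hidden layer: ~819 MB
```

Either way, the activations are roughly 100 to 200 times larger than the ~4 MB of weights, and the backward pass temporarily adds gradient tensors of the same shapes, so a footprint of around 1 GB is plausible. Activation memory scales linearly with batch size, so a smaller batch reduces it roughly in proportion.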