deepglugs / deep_imagen

scripts for running and training imagen-pytorch
38 stars, 8 forks

About multi-GPU on a big dataset #10

Open zhaobingbingbing opened 1 year ago

zhaobingbingbing commented 1 year ago

Hi, thanks for your script. I'm training Imagen with multiple GPUs on a subset of LAION (about 7M image-text pairs). I found that running 'CUDA_VISIBLE_DEVICES=0,1,2,3,4 python3 imagen.py --train...' is faster than 'accelerate launch imagen.py', even though GPU utilization is lower with the first method. I also found a bottleneck in data processing (class ImageLabelDataset in data_generator.py): the GPUs are always waiting for data. Right now, training is too slow with both methods. Do you have any advice? Thanks again.
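One way to confirm that the GPUs are waiting on data is to time how long each step spends fetching a batch versus running the training step. Below is a minimal sketch that works with any iterable loader. The function name and the fake loader are hypothetical, and `train_step` stands in for whatever the training script does per batch.

```python
import time
from typing import Any, Callable, Iterable, Optional, Tuple


def time_data_vs_compute(
    loader: Iterable[Any],
    train_step: Callable[[Any], None],
    max_steps: Optional[int] = None,
) -> Tuple[float, float, int]:
    """Return (seconds waiting for batches, seconds inside train_step, steps run)."""
    data_time = compute_time = 0.0
    steps = 0
    batches = iter(loader)
    while max_steps is None or steps < max_steps:
        t0 = time.perf_counter()
        try:
            batch = next(batches)  # time spent here is time the GPU sits idle
        except StopIteration:
            break
        t1 = time.perf_counter()
        # With CUDA, call torch.cuda.synchronize() at the end of train_step:
        # kernels run asynchronously, so otherwise compute time is undercounted.
        train_step(batch)
        t2 = time.perf_counter()
        data_time += t1 - t0
        compute_time += t2 - t1
        steps += 1
    return data_time, compute_time, steps


if __name__ == "__main__":
    def slow_loader(n):
        for i in range(n):
            time.sleep(0.02)  # stand-in for slow decoding/resizing in a Dataset
            yield i

    data_s, compute_s, n = time_data_vs_compute(slow_loader(10), lambda b: time.sleep(0.005))
    print(f"{n} steps: {data_s:.2f}s waiting for data, {compute_s:.2f}s training")
```

If the data time dominates, adding loader workers or preprocessing the images will help more than adding GPUs.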

deepglugs commented 1 year ago

I'm also struggling to get multi-GPU training to run at reasonable speeds. If the dataloader is the bottleneck, you can try increasing the number of workers with --workers, or you can preprocess your data. For training the first unet, that would mean resizing and padding the images to 64x64 ahead of time.
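A minimal offline preprocessing sketch using Pillow is shown below. It pads each image to a centered square, then resizes it to 64x64, so the dataloader no longer has to do this per batch. The function names, the black fill color, and the flat one-directory layout are assumptions. Caption/tag files are not handled here and would need to be copied alongside the images in whatever layout the training script expects.

```python
from pathlib import Path

from PIL import Image

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".webp"}


def pad_and_resize(img: Image.Image, size: int = 64, fill=(0, 0, 0)) -> Image.Image:
    """Pad to a centered square canvas (no cropping, no distortion), then resize."""
    img = img.convert("RGB")
    w, h = img.size
    side = max(w, h)
    canvas = Image.new("RGB", (side, side), fill)
    # Center the original so the padding is split evenly on both sides.
    canvas.paste(img, ((side - w) // 2, (side - h) // 2))
    return canvas.resize((size, size), Image.BICUBIC)


def preprocess_dir(src: Path, dst: Path, size: int = 64) -> int:
    """Write a padded size x size PNG copy of every image in src to dst."""
    dst.mkdir(parents=True, exist_ok=True)
    count = 0
    for path in sorted(src.iterdir()):
        if path.suffix.lower() not in IMAGE_SUFFIXES:
            continue  # caption/tag files must be carried over separately
        with Image.open(path) as img:
            pad_and_resize(img, size).save(dst / f"{path.stem}.png")
        count += 1
    return count
```

For a 7M-image dataset, you would likely want to run this across several processes (for example, one per subdirectory). The single loop above is kept simple for clarity.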

If you find any other ways to improve multi-GPU performance, let me know.

zhaobingbingbing commented 1 year ago

I found that most diffusion model (DM) codebases use a transform like self.transform = T.Compose([T.Resize(image_size), T.RandomHorizontalFlip(), T.CenterCrop(image_size)]). Is padding necessary?

deepglugs commented 1 year ago

Padding is necessary if the images aren't already square; otherwise they will be distorted when resized to a square. CenterCrop also avoids distortion, but you lose the parts of the image outside the crop.
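The trade-off can be made concrete. A square center crop keeps only short_side / long_side of the pixels, so a 4:3 photo loses a quarter of its area. Padding keeps every pixel and adds filler borders instead. The small sketch below computes both. The function names are hypothetical. The padding tuple uses the (left, top, right, bottom) order that torchvision's T.Pad accepts for a 4-element padding.

```python
from typing import Tuple


def center_crop_kept_fraction(width: int, height: int) -> float:
    """Fraction of the original pixels kept by a square center crop."""
    short_side, long_side = sorted((width, height))
    return short_side / long_side


def square_padding(width: int, height: int) -> Tuple[int, int, int, int]:
    """(left, top, right, bottom) padding that squares the image without cropping."""
    side = max(width, height)
    dx, dy = side - width, side - height
    # Put the extra pixel on the right/bottom when the difference is odd.
    return dx // 2, dy // 2, dx - dx // 2, dy - dy // 2


if __name__ == "__main__":
    print(center_crop_kept_fraction(640, 480))  # 0.75
    print(square_padding(640, 480))             # (0, 80, 0, 80)
```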

deepglugs commented 1 year ago

Update: I'm working on adding webdataset as an optional alternative to ImageLabelDataset. So far it is much faster, but I haven't gotten it working with multi-GPU yet. Once I do, I'll push the change (or maybe I'll push it sooner, since it's a non-default option).
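webdataset reads samples from plain tar "shards." Files that are stored consecutively and share a basename (for example, 000000001.jpg and 000000001.txt) are grouped into one sample. The sketch below packs image/caption pairs into such shards using only the standard library's tarfile. It assumes the image bytes are already JPEG-encoded. The shard naming, `write_shards`, and the shard size are hypothetical choices, not what this repo uses.

```python
import io
import tarfile
from pathlib import Path
from typing import Iterable, List, Tuple


def _add_bytes(tar: tarfile.TarFile, name: str, data: bytes) -> None:
    info = tarfile.TarInfo(name=name)
    info.size = len(data)
    tar.addfile(info, io.BytesIO(data))


def write_shards(
    samples: Iterable[Tuple[bytes, str]],
    out_dir: Path,
    samples_per_shard: int = 10_000,
) -> List[Path]:
    """Pack (jpeg_bytes, caption) pairs into tar shards; return the shard paths."""
    out_dir.mkdir(parents=True, exist_ok=True)
    shards: List[Path] = []
    tar = None
    try:
        for i, (jpeg_bytes, caption) in enumerate(samples):
            if i % samples_per_shard == 0:  # start a new shard
                if tar is not None:
                    tar.close()
                shards.append(out_dir / f"shard-{len(shards):06d}.tar")
                tar = tarfile.open(shards[-1], "w")
            key = f"{i:09d}"  # files sharing this basename form one sample
            _add_bytes(tar, f"{key}.jpg", jpeg_bytes)
            _add_bytes(tar, f"{key}.txt", caption.encode("utf-8"))
    finally:
        if tar is not None:
            tar.close()
    return shards
```

On the reading side, something like wds.WebDataset(urls).decode("pil").to_tuple("jpg", "txt") should yield (PIL image, caption) pairs. Check the webdataset documentation for the version you have installed.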

deepglugs commented 1 year ago

Pushed webdataset support. Multi-GPU is now fast, although I'm not sure everything is working correctly: when training unet2, I see loss=0.0, which isn't right. Debugging continues...
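When webdataset runs on several GPUs, each process should read a disjoint set of shards, and so should each DataLoader worker within a process. Otherwise, every GPU trains on the same samples and the extra GPUs add no throughput. The pure-Python sketch below shows the usual round-robin assignment. It mirrors the idea behind webdataset's split_by_node and split_by_worker, but `shards_for` is a hypothetical helper, so check the webdataset docs for how your version wires this in.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shards_for(
    shards: Sequence[T],
    rank: int,
    world_size: int,
    worker_id: int = 0,
    num_workers: int = 1,
) -> List[T]:
    """Shards read by one DataLoader worker of one GPU process."""
    per_rank = list(shards[rank::world_size])  # split across GPU processes first
    return per_rank[worker_id::num_workers]    # then across that process's workers
```

If the shard count isn't divisible by world_size × num_workers, some workers get one more shard than others, and ranks can end up with different numbers of batches. Choosing a shard count that divides evenly avoids this.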

zhaobingbingbing commented 1 year ago

When I reduced the dataset from 7M to 100k pairs, training was fast: about 0.5h per epoch. With 7M, however, it would take about 200h.
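A quick sanity check on those numbers assumes both figures are per-epoch times. If throughput stayed constant, 70 times more data would take 0.5h × 70 = 35h. The reported 200h is about 5.7x slower per sample (roughly 56 vs 10 samples/s). This suggests that something other than the model's compute scales badly with dataset size, which fits the suspicion that data loading is the bottleneck. The helper name is hypothetical.

```python
def scaling_check(small_n, small_hours, big_n, big_hours):
    """Compare the observed epoch time on the big dataset with linear scaling."""
    expected_hours = small_hours * big_n / small_n  # constant-throughput estimate
    slowdown = big_hours / expected_hours           # >1 means per-sample cost grew
    small_rate = small_n / (small_hours * 3600)     # samples per second
    big_rate = big_n / (big_hours * 3600)
    return expected_hours, slowdown, small_rate, big_rate


expected, slowdown, small_rate, big_rate = scaling_check(100_000, 0.5, 7_000_000, 200)
print(f"linear estimate: {expected:.0f}h, observed is {slowdown:.1f}x slower per sample")
print(f"{small_rate:.1f} vs {big_rate:.1f} samples/s")
```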

deepglugs commented 1 year ago

Is that with webdataset or the default?

zhaobingbingbing commented 1 year ago

The default. The problem seems to be in data processing: when the dataset is very large, most of the time is spent fetching the data for each batch rather than training. If I find a way to improve it, I'll share it with you.

zhaobingbingbing commented 1 year ago

Hi, with 100k image-text pairs, I find that the loss drops significantly during the first few epochs (from 0.6 to 0.02). After 5 epochs, the loss barely declines, but the sampling quality still improves. So when should I stop training?

deepglugs commented 1 year ago

Loss goes down slowly after a while. This is from one of my longer runs. You can also lower the learning rate, which might help drop the loss a bit more (but still very slowly). I usually train with --lr 1e-4 until the loss stops dropping and then switch to --lr 1e-5. Increasing the batch size is also known to help drop the loss.
[Image: loss curve from one of deepglugs's longer training runs]
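The manual rule above (train at 1e-4 until the loss stops dropping, then switch to 1e-5) could be automated. Below is a minimal sketch that assumes you record per-step losses in a list. The function name, window size, and 1% threshold are illustrative values, not settings from this repo. PyTorch's built-in torch.optim.lr_scheduler.ReduceLROnPlateau (with its factor, patience, and min_lr arguments) implements a similar idea if you'd rather attach it directly to the optimizer.

```python
from typing import Sequence


def lr_after_plateau(
    losses: Sequence[float],
    lr: float,
    window: int = 1000,
    min_rel_improvement: float = 0.01,
    factor: float = 0.1,
    min_lr: float = 1e-5,
) -> float:
    """Cut lr by `factor` once the loss stops improving, never going below min_lr.

    Compares the mean loss of the last `window` steps with the window before it.
    If it has not dropped by at least `min_rel_improvement` (1% by default), the
    loss is treated as plateaued.
    """
    if len(losses) < 2 * window:
        return lr  # not enough history to judge yet
    previous = sum(losses[-2 * window:-window]) / window
    recent = sum(losses[-window:]) / window
    if recent > previous * (1 - min_rel_improvement):
        return max(lr * factor, min_lr)
    return lr
```

Averaging over a window matters because per-step diffusion losses are noisy. Comparing single steps would trigger drops at random.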

zhaobingbingbing commented 1 year ago

My loss curve is similar to yours. With longer training, the loss is still hard to reduce, and I have tried lowering the LR. But I found that sampling quality keeps improving with longer training. So apart from the loss, what can be used as a criterion for convergence?

deepglugs commented 1 year ago

Image sample quality is the best method I know of.