Instead of loading the entire dataset into CPU memory in every process, we load it only in process 0 (the one responsible for GPU 0) on each node. Since each node has 4 GPUs, this saves 4x CPU memory per node. During training, process 0 moves the images to GPU 0 and uses torch.distributed.scatter to send them to the other GPUs.
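A minimal sketch of the idea (not the exact implementation), assuming a single node with an NCCL process group and a hypothetical `load_images_uint8()` loader that returns one uint8 CPU tensor of shape (N, 3, H, W) with N divisible by the world size:

```python
import torch
import torch.distributed as dist

def distribute_images(rank, world_size, shard_shape):
    device = torch.device(f"cuda:{rank}")
    recv = torch.empty(shard_shape, dtype=torch.uint8, device=device)
    if rank == 0:
        images = load_images_uint8()                         # full dataset lives in CPU memory on rank 0 only
        shards = list(images.to(device).chunk(world_size))   # move to GPU 0, split into one shard per rank
        dist.scatter(recv, scatter_list=shards, src=0)
    else:
        dist.scatter(recv, src=0)                            # other ranks never hold a CPU copy of the dataset
    return recv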
Instead of converting images from uint8 to fp32 on the CPU, we keep them as uint8 and convert them on the GPU only when they are needed for loss computation. This saves 4x CPU memory.
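A minimal sketch of this step, where `render` and `gt_uint8` are placeholders for the rendered image and the stored uint8 ground-truth image of the current camera:

```python
import torch
import torch.nn.functional as F

def image_loss(render: torch.Tensor, gt_uint8: torch.Tensor) -> torch.Tensor:
    # gt_uint8 stays uint8 (1 byte/pixel) on the CPU; it becomes fp32 (4 bytes/pixel)
    # only on the GPU, and only for the duration of the loss computation.
    gt = gt_uint8.to(render.device, non_blocking=True).float() / 255.0
    return F.l1_loss(render, gt)
```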
After converting each PIL image to a torch.Tensor, we call image.close() to release the PIL image buffer. This saves 2x CPU memory.
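A minimal sketch of the loading step (the exact code in the repo may differ):

```python
import numpy as np
import torch
from PIL import Image

def load_image_uint8(path):
    image = Image.open(path)
    tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1)  # copy HWC uint8 pixels into a CHW tensor
    image.close()  # release the decoded PIL buffer so only one copy remains in CPU memory
    return tensor
```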
Implemented a distributed 3DGS save mode, so that process 0 does not crash when the model is very large.
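A minimal sketch of the idea, not the actual implementation: instead of gathering all Gaussians to process 0 and writing one file, every process writes its own shard. `local_gaussians` is a hypothetical dict of this rank's parameter tensors.

```python
import os
import torch
import torch.distributed as dist

def save_3dgs_distributed(local_gaussians: dict, out_dir: str):
    rank = dist.get_rank()
    os.makedirs(out_dir, exist_ok=True)
    shard = {name: param.detach().cpu() for name, param in local_gaussians.items()}
    torch.save(shard, os.path.join(out_dir, f"gaussians_rank{rank}.pt"))
    dist.barrier()  # ensure every shard is on disk before any process reports success
```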
multiprocesses_image_loading is implemented, but it does not give any speedup, and I do not know the reason yet. This matters because loading 1600 4K images to the CPU takes about 10 minutes.
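For reference, a minimal sketch of what multiprocesses_image_loading presumably does (worker processes decode images in parallel); the repo's implementation may differ. One thing worth checking is that the decoded arrays are copied (pickled) back to the parent process, which could eat into the expected speedup.

```python
import multiprocessing as mp
import numpy as np
from PIL import Image

def _decode(path):
    with Image.open(path) as img:
        return np.asarray(img.convert("RGB"))  # decode happens in the worker process

def load_images_parallel(paths, num_workers=8):
    with mp.Pool(num_workers) as pool:
        return pool.map(_decode, paths)        # decoded arrays are copied back to the parent
```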