veritas9872 closed this issue 4 years ago.
Hi! Thanks for your contribution! Great first issue!
@veritas9872
This is interesting. We already do a kind of prefetching in the training loop to support iterable datasets, so this would only require moving the device-handling code and setting the non_blocking flag. Presumably we would do this only if the user asks for it (for example with a Trainer flag like device_prefetch or something similar), in case holding two batches of data in memory is a problem in some settings. I can take a look at this if needed :)
@williamFalcon Thank you for the positive review.
However, I am not familiar enough with the PyTorch Lightning library to implement this.
I only wanted to propose a very simple solution that I believe the deep learning community is overlooking.
Most people who learn CUDA C/C++ know cudaMemcpyAsync as an important function for overlapping data transfer with computation.
However, in typical research settings most 2D CNNs spend only a small fraction of their total time on data transfer, which is why most researchers do not care much about asynchronous data transfer.
Researchers in industry do care about this, but frankly, I find that solutions such as DALI or even tf.data are over-engineered for a problem that only requires a simple fix.
I hope that this practice will be adopted more widely by the research community.
@veritas9872 I've looked into this a little. Any dataloader can be wrapped or subclassed to preload batches to the GPU while keeping an interface identical to a normal PyTorch DataLoader, so no extra changes to Lightning would be needed: https://github.com/HenryJia/Lighter/blob/master/lighter/train/loaders.py
However, this would only work for a single GPU, since torch.nn.parallel.DataParallel.scatter() already handles data transfer for multiple GPUs and appears to have explicit synchronisation conditions for it.
The main question is whether this should be added to Lightning as some sort of example code, since it doesn't require any other changes.
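A minimal sketch of that wrapping idea, assuming only that the wrapped loader is iterable and sized, and that a caller-supplied `to_device` function moves one batch to the GPU. `DevicePrefetchLoader` is a hypothetical name, not the class in the linked repository. With PyTorch, `to_device` could be something like `lambda batch: batch.to("cuda", non_blocking=True)` for single-tensor batches, with the DataLoader created with `pin_memory=True` so that the copy can actually be asynchronous.

```python
from typing import Any, Callable, Iterable, Iterator


class DevicePrefetchLoader:
    """Hypothetical wrapper: iterates like the wrapped loader, but issues the
    device copy of batch i+1 before batch i is handed to the caller."""

    def __init__(self, loader: Iterable[Any], to_device: Callable[[Any], Any]):
        self.loader = loader
        self.to_device = to_device  # assumed hook, e.g. a non_blocking .to("cuda")

    def __len__(self) -> int:
        # Same number of batches as the wrapped loader.
        return len(self.loader)

    def __iter__(self) -> Iterator[Any]:
        it = iter(self.loader)
        try:
            next_batch = self.to_device(next(it))  # start copying the first batch
        except StopIteration:
            return  # empty loader: nothing to yield
        while True:
            current = next_batch
            try:
                # Issue the copy of the following batch *before* yielding the
                # current one, so an asynchronous copy can overlap with the
                # caller's compute on `current`.
                next_batch = self.to_device(next(it))
            except StopIteration:
                yield current  # last batch: nothing left to prefetch
                return
            yield current
```

Because the wrapper keeps `__iter__` and `__len__`, a training loop can iterate over it exactly as it would over the original loader; per the comment above, this targets the single-GPU case.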
@HenryJia Hello. I have looked at the source code of your solution and I think it could be a good idea for single-GPU cases.
However, have you checked with NVIDIA Nsight or the NVIDIA Visual Profiler that asynchronous transfer actually occurs with your code?
@williamFalcon I would also like to point out that if my solution is combined with Horovod, asynchronous data transfer is possible even in multi-GPU cases. To the best of my understanding, PyTorch Lightning is compatible with Horovod.
@veritas9872 I have checked it with a simple example here: https://gist.github.com/HenryJia/17e3a647cc2da1dd0ceeb6365bdfeaac
@HenryJia Thank you for the check!
@veritas9872 We already support Horovod :) Just set the backend to horovod.
@veritas9872 I assume this is solved for now (unless you want to add DALI), but feel free to reopen if needed :rabbit:
@veritas9872 What is the difference between the while loop and the for loop? Why does the while loop allow prefetching while the for loop doesn't? Thank you.
@caiodataopshouse This is because the while loop sends the next batch to the GPU asynchronously while the current batch is being computed. The for loop sends the data to the GPU only immediately before the computation, which is inefficient, especially when the data is large.
Thank you for the reply @veritas9872. Sorry for my ignorance, but does this difference come from the PyTorch implementation or from Python itself? In my mind, a block of code would be read pretty much the same way in a for loop and in a while loop.
@caiodataopshouse It is caused by CUDA's asynchronous execution model. When the tensor is already in pinned memory and non_blocking=True is given, the copy call returns to the host without waiting for the transfer to finish. However, the call that sends the next tensor to the GPU must be issued while the current tensor is being computed, hence the while loop. Please refer to the CUDA documentation for a more detailed explanation. The TensorFlow documentation also has some useful tips on data pipelining.
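In Python itself the two loops execute statements the same way, one after another; what differs is the order in which the host issues the calls. The sketch below records that order, with the strings "copy" and "compute" standing in for something like `batch.to("cuda", non_blocking=True)` and the training step (placeholders, not Lightning or PyTorch APIs). In the while-loop version the copy of batch i+1 is issued before the compute on batch i, which is what gives CUDA the chance to overlap the two.

```python
_DONE = object()  # end-of-iteration sentinel (works even if a batch is None)


def for_loop_order(batches):
    """Host-side call order of `for b in batches: compute(copy(b))`."""
    log = []
    for b in batches:
        log.append(f"copy {b}")     # copy of batch b is issued...
        log.append(f"compute {b}")  # ...immediately before the compute that needs it
    return log


def while_loop_order(batches):
    """Host-side call order after unrolling the loop into a while loop."""
    log = []
    it = iter(batches)
    nxt = next(it, _DONE)
    if nxt is not _DONE:
        log.append(f"copy {nxt}")  # copy of the first batch
    while nxt is not _DONE:
        cur, nxt = nxt, next(it, _DONE)
        if nxt is not _DONE:
            log.append(f"copy {nxt}")  # copy of batch i+1 issued before compute i
        log.append(f"compute {cur}")
    return log


print(for_loop_order(range(3)))
# -> ['copy 0', 'compute 0', 'copy 1', 'compute 1', 'copy 2', 'compute 2']
print(while_loop_order(range(3)))
# -> ['copy 0', 'copy 1', 'compute 0', 'copy 2', 'compute 1', 'compute 2']
```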
Thank you @veritas9872.
🚀 Feature
Copying data from host (CPU) to device (GPU) is a time-consuming operation that can cause GPU starvation, leaving the GPU idle while it waits for data.
On a PCIe 3.0 connection (the most common for GPUs), a 16 MB chunk of pinned data moves at approximately 6 GB/s. While this is fine for most research workloads, it can be a serious bottleneck in industrial use or with large data (such as 3D volumes).
Smaller transfers achieve even lower effective throughput because of a constant per-transfer overhead.
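As a rough worked example of the figures above, one 16 MB chunk at about 6 GB/s takes a few milliseconds. The second line uses a simple latency-plus-bandwidth model, where $t_0$ is an assumed fixed per-transfer overhead, $S$ the transfer size, and $B$ the link bandwidth; it only illustrates why small transfers are slower per byte and is not a measured characterisation of PCIe.

$$
t_{16\,\text{MB}} \approx \frac{16 \times 10^{6}\ \text{B}}{6 \times 10^{9}\ \text{B/s}} \approx 2.7\ \text{ms}
$$

$$
t(S) = t_0 + \frac{S}{B}, \qquad \text{effective throughput} = \frac{S}{t(S)} = \frac{S}{t_0 + S/B} < B
$$

Under this model the effective throughput approaches $B$ only when $S/B$ is much larger than $t_0$.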
Motivation
GPU prefetching is a useful feature that already exists in other libraries (e.g. TensorFlow, NVIDIA DALI).
However, these libraries use graph mode to prefetch data to the GPU.
This is not necessary: a slight adjustment to the Trainer class would allow prefetching data to the GPU without any new library.
This is because PyTorch already supports asynchronous GPU transfer through .to(non_blocking=True).
All that needs to be done is to unroll the Python for loop into a while loop so that the next mini-batch is sent to the GPU(s) while they are busy running the model.
According to the PyTorch NeurIPS 2019 paper, PyTorch queues GPU commands while the CPU asynchronously continues with the next piece of host code.
This means that the host-side command to send the data asynchronously to the GPU(s) is issued while the GPUs are still running, which allows the data transfer to overlap with computation.
Pitch
A Python for loop is almost always used to iterate over a DataLoader during training.
However, any for loop can be expressed as a while loop.
Most importantly, a while loop allows the next mini-batch to be prefetched from CPU to GPU while the current mini-batch is being trained on the GPU.
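A minimal sketch of such a training loop, assuming hypothetical `to_device` and `train_step` callables (they are not Lightning or PyTorch APIs). With PyTorch, `to_device` would typically call `.to("cuda", non_blocking=True)` on tensors coming from a DataLoader with `pin_memory=True`, and `train_step` would run the forward pass, backward pass, and optimizer step. Like the proposal, it keeps exactly one mini-batch in flight.

```python
from typing import Any, Callable, Iterable, List

_DONE = object()  # sentinel marking the end of the loader


def train_one_epoch(
    loader: Iterable[Any],
    to_device: Callable[[Any], Any],   # hypothetical hook: host -> device copy
    train_step: Callable[[Any], Any],  # hypothetical hook: one optimization step
) -> List[Any]:
    """While-loop form of `for batch in loader: train_step(to_device(batch))`."""
    results = []
    it = iter(loader)
    batch = next(it, _DONE)
    if batch is not _DONE:
        batch = to_device(batch)  # copy of the first mini-batch
    while batch is not _DONE:
        next_batch = next(it, _DONE)
        if next_batch is not _DONE:
            # Issue the copy of mini-batch i+1 before running step i.
            next_batch = to_device(next_batch)
        results.append(train_step(batch))  # compute on mini-batch i
        batch = next_batch
    return results
```

On a single CUDA stream, operations run in the order they were issued, so overlapping the copy with the model's kernels on the GPU itself generally requires issuing the copy on a separate stream (torch.cuda.Stream). Whether the overlap actually happens is worth confirming with a profiler such as Nsight Systems, as asked earlier in the thread.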
This is a very general idea and it is not difficult to implement, though I am not sure whether TPUs support asynchronous data transfer.
Alternatives
Integrate NVIDIA DALI for PyTorch into PyTorch Lightning.
Additional context
The code example prefetches only one mini-batch to the GPU while training is going on; it does not send multiple mini-batches ahead.
However, since GPU compute is almost always the bottleneck, and CPU bottlenecks can be handled simply by increasing the number of workers, the proposed solution is adequate.