As discussed with @emailweixu, it would be nice to have alf support multi-GPU training. The goals are
@breakds FYI, two reference papers I came across a while ago (RL scenario):
https://openreview.net/pdf?id=H1gX8C4YPr
https://ai.googleblog.com/2020/03/massively-scaling-reinforcement.html

Although they were proposed for multi-machine training, our single-machine multi-GPU case is a special and simpler case. Or refer to PyTorch's official multi-GPU support (general DL scenario): https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html
Thanks @hnyu for the references!
The first paper's (DD-PPO) main idea is to early-stop slow simulations during rollout with a batched environment (potentially distributed over different machines in a cluster), and then, in each training iteration, use the full experience from some of the environments and partial experience from the early-stopped ones. I think we can borrow these ideas in the near future.

As a first step, I will look into how PyTorch's `DataParallel` is implemented and use it (or similar techniques) to enable multi-GPU training on a single machine in each training iteration.
Currently I am hitting two problems with `DataParallel`:

1. When tried on a simple `backward()` operation, the `DataParallel` version (2 GPUs) takes 1.8 seconds while the single-GPU version takes only 0.5 seconds, and I haven't figured out why. This at least suggests that the overhead of `DataParallel` is pretty significant.
2. We cannot directly apply `DataParallel` on our network; it will crash. This is another thing that I am working on.
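For context, the kind of comparison in item 1 can be reproduced with a small harness along these lines (a sketch with a placeholder model and sizes, not the exact script; the actual network is shown in a later comment):

```python
import time

import torch
import torch.nn as nn


def time_backward(model, batch, iters=100):
    # Time `iters` forward+backward passes; synchronize so that queued
    # GPU work is actually included in the measurement.
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        loss = model(batch).sum()
        loss.backward()
    torch.cuda.synchronize()
    return time.time() - start


net = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 1)).cuda()
batch = torch.rand(4096, 256, device='cuda')
print('single GPU  :', time_backward(net, batch))
print('DataParallel:', time_backward(nn.DataParallel(net), batch))
```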
I think multi-gpu only makes sense for a large mini-batch with intensive computation. What is your setup?
Yep, I think that is what happened. I was testing the `forward` and `backward` of a network like this:
```python
import torch.nn as nn


class Network(nn.Module):
    def __init__(self, input_size, output_size):
        super(Network, self).__init__()
        self.fc1 = nn.Linear(input_size, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 384)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(384, 64)
        self.relu3 = nn.ReLU()
        self.fc4 = nn.Linear(64, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        h = self.fc1(x)
        h = self.relu1(h)
        h = self.fc2(h)
        h = self.relu2(h)
        h = self.fc3(h)
        h = self.relu3(h)
        h = self.fc4(h)
        h = self.sigmoid(h)
        return h
```
I realized that this network is probably too small, because even if a batch of 25600 is passed in, the consumed time pretty much does not change. I am now trying to fix issue No. 2 so that I can run the experiment on an actual network that is used in `alf`.
Our expected scenario for multi-GPU is image inputs with a large batch size, so you could try dummy image inputs instead. Besides running time, another scenario is to split the SGD memory consumption across multiple cards, if one card is not enough.
That makes a lot of sense. Thanks for the suggestions and clarification!
I was using `ActorDistributionNetwork` with a batch of randomly generated images to run the experiment, and got:
```
Traceback (most recent call last):
  File "/nix/store/4s0h5aawbap3xhldxhcijvl26751qrjr-python3-3.8.9/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/nix/store/4s0h5aawbap3xhldxhcijvl26751qrjr-python3-3.8.9/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/breakds/projects/alf/alf/bin/experiment/dp_network_experiment.py", line 42, in <module>
    action_distribution, actor_state = actor_network(observation, state=())
  File "/nix/store/1nhxgafz45v9sivabxw0aqr0dvpyw1nc-python3-3.8.9-env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/nix/store/1nhxgafz45v9sivabxw0aqr0dvpyw1nc-python3-3.8.9-env/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
    return self.gather(outputs, self.output_device)
  File "/nix/store/1nhxgafz45v9sivabxw0aqr0dvpyw1nc-python3-3.8.9-env/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 180, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/nix/store/1nhxgafz45v9sivabxw0aqr0dvpyw1nc-python3-3.8.9-env/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 76, in gather
    res = gather_map(outputs)
  File "/nix/store/1nhxgafz45v9sivabxw0aqr0dvpyw1nc-python3-3.8.9-env/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 71, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
  File "/nix/store/1nhxgafz45v9sivabxw0aqr0dvpyw1nc-python3-3.8.9-env/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 71, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
TypeError: 'Categorical' object is not iterable
```
With some investigation, I realized that it fails because `DataParallel` internally does not know how to combine `Categorical` objects, which are the output of `ActorDistributionNetwork`. `DataParallel` works in two steps:

1. Scatter the input and replicate the module so that each GPU runs `forward()` on its own shard (map).
2. Gather the per-GPU outputs back onto one device (reduce).

This problem happens at the last step, Gather. I will use a slightly modified network to continue the experiment and work around this.
However, the final solution should make multi-GPU as transparent as possible so that it is convenient to use.
Directly applying `DataParallel` may not be the right solution we are looking for, partly because of the above issue. This is something to think about later.
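For what it's worth, the "slightly modified network" workaround could look like the following sketch: a hypothetical wrapper that returns the distribution's parameter tensor (which gather understands) instead of the `Categorical` itself, rebuilding the distribution after the gather. This assumes the returned distribution exposes `logits`; the wrapper name is made up for illustration:

```python
import torch.nn as nn
import torch.distributions as td


class TensorOutputWrapper(nn.Module):
    """Hypothetical workaround: emit plain tensors so that DataParallel's
    gather step never has to combine Categorical objects."""

    def __init__(self, actor_network):
        super().__init__()
        self._actor_network = actor_network

    def forward(self, observation, state=()):
        action_distribution, actor_state = self._actor_network(
            observation, state=state)
        # A Categorical is not gather-able, but its logits tensor is.
        return action_distribution.logits, actor_state


# Usage: wrap before DataParallel, rebuild the distribution afterwards.
# logits, state = nn.DataParallel(TensorOutputWrapper(net))(obs, state=())
# action_distribution = td.Categorical(logits=logits)
```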
After slightly modifying the `ActorDistributionNetwork` (for experiment purposes), I was able to run `DataParallel` with 2 x 3080:
```python
import torch
import torch.nn as nn
import alf
from alf.networks import ActorDistributionNetwork
from alf.tensor_specs import BoundedTensorSpec
import functools
import time

if __name__ == '__main__':
    alf.set_default_device('cuda')
    CONV_LAYER_PARAMS = ((32, 8, 4), (64, 4, 2), (64, 3, 1))
    actor_network_cls = functools.partial(
        ActorDistributionNetwork,
        fc_layer_params=(512, ),
        conv_layer_params=CONV_LAYER_PARAMS)
    actor_network = nn.DataParallel(actor_network_cls(
        input_tensor_spec=BoundedTensorSpec(
            shape=(4, 150, 150), dtype=torch.float32, minimum=0., maximum=1.),
        action_spec=BoundedTensorSpec(
            shape=(), dtype=torch.int64, minimum=0, maximum=3)))

    start_time = time.time()
    for i in range(1000):
        observation = torch.rand(640, 4, 150, 150)
        action_distribution, actor_state = actor_network(observation, state=())
    print(f'{time.time() - start_time} seconds elapsed')
```
I can see the load being distributed to the 2 cards (as well as the memory). However, compared to running the same piece of code on a single 3080 without `DataParallel`:

- The memory consumption of both cards together is significantly greater than that of the single-card non-data-parallel version.
- The non-`DataParallel` version took 6 seconds to finish on a single card, while the `DataParallel` version took 1 minute.

This almost renders `DataParallel` unusable, though I will continue investigating why such odd behavior exists. Will discuss with people with more experience in this tomorrow.
The inefficiency of DataParallel seems unreasonable. There must be something wrong going on.
Or maybe this is by design; I can try to look into where the time is being spent.
According to https://pytorch.org/tutorials/intermediate/ddp_tutorial.html, DataParallel might be even slower than DistributedDataParallel.
Yep, I can see how the GIL issue makes sense. `DistributedDataParallel` is even harder to integrate; if we are willing to spend more effort, it would probably be better to roll our own solution that suits us better.
@hnyu and I chatted about this today, and I agree with Haonan that we might want to adjust our goal and go for a slightly more complicated (i.e. might require structural update) customized solution. We can chat more about this tomorrow.
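For reference, the standard `DistributedDataParallel` setup is roughly the structure the later experiments follow: one process per GPU, started via `mp.spawn`, each wrapping its model in DDP. A minimal runnable sketch (the linear model and the loop are placeholders; `training_worker` mirrors the naming used later in this thread):

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def training_worker(rank: int, world_size: int):
    # One process per GPU; NCCL is the standard backend for GPU collectives.
    dist.init_process_group(
        'nccl', init_method='tcp://localhost:29500',
        rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model = nn.Linear(32, 4).to(rank)  # placeholder for the real network
    ddp_model = DDP(model, device_ids=[rank])
    opt = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

    for _ in range(10):
        loss = ddp_model(torch.rand(64, 32, device=rank)).sum()
        opt.zero_grad()
        loss.backward()  # gradients are all-reduced across processes here
        opt.step()

    dist.destroy_process_group()


if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(training_worker, args=(world_size,), nprocs=world_size)
```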
Two updates:

- Ran the experiment with `ActorDistributionNetwork`. The preliminary result shows about a 25% performance improvement vs the non-parallel version (this is only a single data point, because it is from that 2-GPU machine).
- I want to get `ac_breakout` running, but at this moment I need to resolve #930 first.

After working around the above issue, I ran into another blocker with distributed data parallel. On the surface, the problem is that the DDP wrapper gets stuck. After debugging with pudb, the actual problem is that the underlying reducer is throwing an exception. All online discussion about this leads to this pytorch issue, which seems to be a well-known one. To summarize, the current implementation of DDP assumes all parameters in a module are used; if there are unused parameters, the reducer will fail to communicate their gradients to the other processes.
Documenting possible next steps (I am a bit reluctant to take any path yet):

1. Wait for the upstream fix: `apex` will eventually be merged into pytorch, but not likely any sooner than the next version of pytorch.
2. Bring `apex` in as a dependency. My fear is that by bringing in another blackbox (which is much less well maintained than pytorch itself), we will be getting more problems rather than features.
3. Roll our own solution based on `nccl`. We will have to pay the cost of a major design and refactor; whether it is worth it totally depends on the benefit.

Which algorithm are you trying to run with DDP? For some algorithms (e.g. DDPG, SAC), there are parameters which are not handled by the optimizer and hence have no gradients. For ActorCriticAlgorithm, there shouldn't be such parameters. `Algorithm.get_unoptimized_parameter_info()` can be used to get all the parameters which are not handled by optimizers.
Thanks @emailweixu!

After discussion with Wei, I realized that the conclusion about "unused parameters" was rather premature. The "incompatible function argument" exception happens when a C/C++ function is called from python but the passed argument cannot be converted. In fact, in the above exception, the function at fault is `broadcast_coalesced`: the problem happens because we passed a list that cannot be converted to a `TensorList` for the 2nd argument.
Further tracing the program with a debugger confirmed this. Normally the optimizer is not part of the module, so it does not get into the module's `state_dict`. However, in our case `optimizer` is in the `state_dict`, so it gets into the list to be broadcast; since it is not a `Tensor`, it breaks the type enforcement in the C++ function.
Knowing that, the solution could be:

- The cleaner way might be to get optimizers out of the `state_dict`. I need to think about it further to confirm whether the state of the optimizer needs to be in the checkpoint (if so, we probably want it in `state_dict`). At this moment my preliminary conclusion is that we need it in `state_dict`, because if we resume training from the checkpoint, the state in the optimizer matters.
- The other way is to use `self.parameters_to_ignore` in DDP. I just need to figure out a way to add the optimizer to it.

One way to solve the problem is to restore the semantics of Module.state_dict to not include optimizer state, and add another pair of functions to Algorithm (e.g. `optimizer_state_dict`, `load_optimizer_state_dict`) to handle the state of the optimizer. checkpoint_util also needs to be changed accordingly. cc @Haichao-Zhang since he wrote this part originally.
Thanks @emailweixu! I was able to bypass the problem by adding the `optimizers` into the ignore list (there is a hidden API of DDP for this). Just to be safe, I will add a check before the DDP wrapper to make sure that all of the entries in `state_dict` are `Tensor`s, and if not, add all of the offending ones into the ignore list.
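For illustration, a sketch of that check, assuming the hidden API in question is DDP's underscored static method `_set_params_and_buffers_to_ignore_for_model` (undocumented, so the exact name is an assumption here; `module` and `rank` stand for the wrapped module and the local device):

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

# Collect every state_dict entry that is not a Tensor and tell DDP to
# skip it during its parameter/buffer broadcast.
non_tensor_keys = [
    name for name, value in module.state_dict().items()
    if not isinstance(value, torch.Tensor)
]
if non_tensor_keys:
    DDP._set_params_and_buffers_to_ignore_for_model(module, non_tensor_keys)
ddp_module = DDP(module, device_ids=[rank])
```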
Seems that you have figured out a solution. Let me know if you need any further discussions on this.
After some more investigation and messing around, I was able to run alf training with DDP. There are a lot of places where we need to get the device number right. Currently I can observe that it is using both GPUs (in terms of memory), but only one of them is actually running (the other seems idle to me). A snapshot of `nvidia-smi` showed:
```
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.73.01    Driver Version: 460.73.01    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  GeForce RTX 3080    Off  | 00000000:04:00.0 Off |                  N/A |
|  0%   56C    P2   107W / 340W |   3558MiB / 10015MiB |     27%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 3080    Off  | 00000000:09:00.0 Off |                  N/A |
|  0%   45C    P8    15W / 340W |   1755MiB / 10018MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      1610      G   ...xorg-server-1.20.11/bin/X        9MiB |
|    0   N/A  N/A      1687      G   ...hell-40.1/bin/gnome-shell        8MiB |
|    0   N/A  N/A   1794885      C   ...thon3-3.8.9/bin/python3.8     2217MiB |
|    0   N/A  N/A   1794886      C   ...thon3-3.8.9/bin/python3.8     1319MiB |
|    1   N/A  N/A      1610      G   ...xorg-server-1.20.11/bin/X        4MiB |
|    1   N/A  N/A   1794886      C   ...thon3-3.8.9/bin/python3.8     1747MiB |
+-----------------------------------------------------------------------------+
```
From the log I can clearly see that only one of the 2 processes is running and producing log output. This is weird because, based on my understanding of how DDP works, one process has to wait for the other before it can go on when hitting `backward()`. Will keep debugging along this direction.
Turns out my hypothesis that GPU 0 and GPU 1 run sequentially proved to be wrong. I made 2 mistakes in my toy example with `ActionDistributionNetwork`; the key one is that I never called `backward()`, so the 2 processes were not synchronized except when they were initialized. The good news is that after the fix, the toy example
```python
start_time = time.time()
for i in range(2500):
    observation = torch.rand(batch_size, 4, 84, 84, device=rank)
    action_distribution, actor_state = actor_network(observation, state=())
    action = action_distribution.sample()
    reward = torch.rand(batch_size, device=rank)
    loss = -torch.mean(action_distribution.log_prob(action) * reward)
    loss.backward()
    if i % 100 == 0:
        print(f'iteration {i} - {time.time() - start_time} seconds elapsed on device {rank}')
print(f'{time.time() - start_time} seconds elapsed on device {rank}')
```
proves that when `batch_size` is big enough (in my comparison, the batch size is divided by the number of GPUs for the multi-GPU run), multi-GPU shows some gain in terms of elapsed time:

```
Single GPU: 31.069265604019165 seconds elapsed on device 0

Double GPU:
24.037397384643555 seconds elapsed on device 1
24.038329362869263 seconds elapsed on device 0
```

for batch size = 1024. The bad news is that this does not explain why running `alf.bin.train` leads to only one process progressing. Keep investigating.
I think for this toy example, you can also try more complex CNN architectures (like `ResnetEncodingNetwork`) to see the gain.
Turns out the problem is still that one of the two processes throws an exception, but that exception is not observed. I actually researched (experimented with) how exceptions in nested subprocesses work yesterday, and thought I had solved this problem. Sadly, there are still cases where such an exception is raised silently. Normally I would expect it to show up in the terminal, because I explicitly catch it and print it in the offending process. The exception itself looks like this:
```
  File "/home/breakds/projects/alf/alf/utils/tensor_utils.py", line 87, in tensor_extend_zero
    return torch.cat((x, torch.zeros(1, *x.shape[1:], dtype=x.dtype)))
RuntimeError: All input tensors must be on the same device. Received cuda:1 and cuda:0
```
After going DDP, we would like newly created tensors to be placed on a process-dependent default device (a.k.a. device id, or rank). I have patched quite a few places with `to(rank)`, but it turns out that there are too many occurrences to cover (because I want to get this working before committing to making them clean and nice). I will have to go back to researching a process-dependent default rank.
`torch.cuda.set_device(rank)` seems to do the trick. But the behavior is still that only device 0's process progresses. Keep investigating.
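A quick illustration of why this helps (`rank` comes from the surrounding worker process): factory calls that target the generic `'cuda'` device resolve to the current device, so per-process defaults fall out without sprinkling `to(rank)` everywhere:

```python
import torch

torch.cuda.set_device(rank)        # make cuda:<rank> the current device
x = torch.zeros(4, device='cuda')  # resolves to cuda:<rank>
y = torch.cat((x, torch.zeros(1, dtype=x.dtype, device='cuda')))  # same device
```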
> I think for this toy example, you can also try more complex CNN architectures (like `ResnetEncodingNetwork`) to see the gain.
Acknowledged. Will try it later. Thanks!
More updates: with some other small problems fixed, I am now able to train with 2 GPUs under the DDP wrapper:
```
INFO:absl:[rank=0] None -> ac_breakout: 79 time=2.391 throughput=107.07
INFO:absl:[rank=1] None -> ac_breakout: 79 time=2.346 throughput=109.14
INFO:absl:[rank=1] None -> ac_breakout: 85 time=0.185 throughput=1385.38
INFO:absl:[rank=0] None -> ac_breakout: 85 time=0.191 throughput=1341.97
INFO:absl:[rank=0] None -> ac_breakout: 89 time=2.386 throughput=107.31
INFO:absl:[rank=1] None -> ac_breakout: 89 time=2.463 throughput=103.93
INFO:absl:[rank=1] None -> ac_breakout: 95 time=0.177 throughput=1444.35
INFO:absl:[rank=0] None -> ac_breakout: 95 time=0.178 throughput=1436.04
INFO:absl:[rank=0] None -> ac_breakout: 99 time=2.333 throughput=109.75
INFO:absl:[rank=1] None -> ac_breakout: 99 time=2.446 throughput=104.67
INFO:absl:[rank=0] None -> ac_breakout: 105 time=0.165 throughput=1550.85
INFO:absl:[rank=1] None -> ac_breakout: 105 time=0.169 throughput=1518.86
INFO:absl:[rank=0] None -> ac_breakout: 109 time=2.363 throughput=108.32
INFO:absl:[rank=1] None -> ac_breakout: 109 time=2.405 throughput=106.46
INFO:absl:[rank=0] None -> ac_breakout: 115 time=0.177 throughput=1446.16
INFO:absl:[rank=1] None -> ac_breakout: 115 time=0.176 throughput=1453.41
```
The synchronization is there too. The trained result cannot be played yet (which is expected); I will take a closer look at the checkpoints. Meanwhile, I will start to think about a cleaner implementation.
Outline for the plan of next steps, after discussion on 2021/07/23:

- [x] Update `train.py`
- [x] How does DDP guarantee that all `optimizer.step()` calls are synchronized?
  - Ideally we could run `optimizer.step()` in one process instead of having to run `optimizer.step()` on multiple processes doing (supposedly) exactly the same work.
- [x] Figure out how DDP establishes backward callbacks. This proves whether wrapping the algorithm is enough.
- [x] Only the rank 0 process writes checkpoints
- [x] Can we still run `play.py`?
- [ ] Does the replay buffer synchronize over DDP?
- [ ] Tensorboard: have any metrics changed their meaning?
- [x] How is batchnorm affected? Do we have to use `SyncBatchNorm` as mentioned in the source code of DDP?
- [x] Debugging and `Ctrl-C` handling
- [x] Figuring out why starting environments takes longer in DDP mode
- [ ] Figuring out why DDP makes each iteration slower
I found that after training for 10 minutes, they start to behave as if not synchronized. I think I have some misunderstanding of how DDP works.
Had a discussion with @emailweixu while reading the DDP code, and we figured out why the above approach (directly wrapping `Algorithm`) does not work. The steps below demonstrate how DDP works in one iteration, assuming `m` is the original module and `w = DDP(m)` is `m` with the DDP wrapper:
1. `w` hijacks `m`'s `forward()`.
2. When `output = w.forward()` is called, it calls `m.forward()` under the hood.
3. `w.forward()` does something extra: it registers a `_DDPSink` on the returned `output`.
4. When `output.backward()` is called, it calls `_DDPSink`'s `backward` callback.
5. The `backward` callback in turn injects a "reducer callback" at the end of the current computation graph's computation.
6. When `backward()` is done, the "reducer callback" is invoked, and that does the synchronization.

This explains why wrapping `Algorithm` won't work: we are not calling `Algorithm`'s `forward()`. In fact, it does not even have a `backward()`.
The next idea to try is to wrap whatever produces `train_info`, making `train_info` part of the output of a `forward()`. This tricks DDP into doing what we want.
Now, with a DDP wrapper applied to the `unroll()` of `RLAlgorithm`, training ran successfully on 2 GPUs:

```python
class RLAlgorithm:
    def activate_ddp(self, rank: int):
        self.__dict__['_unroll_performer'] = DDP(
            UnrollPerformer(self), device_ids=[rank])
```

Note that `UnrollPerformer` is itself an `nn.Module` whose `forward` wraps `unroll`.
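A minimal sketch of what such a wrapper could look like (the actual `unroll` signature in alf may differ; the point is that routing `unroll` through an `nn.Module.forward` lets DDP attach its gradient-synchronization hooks):

```python
import torch.nn as nn


class UnrollPerformer(nn.Module):
    """Expose RLAlgorithm.unroll() as forward() so DDP can hook it."""

    def __init__(self, algorithm):
        super().__init__()
        # Registering the algorithm as a submodule makes its parameters
        # visible to DDP for broadcasting and gradient reduction.
        self.algorithm = algorithm

    def forward(self, unroll_length):
        # The returned experience/train_info contains tensors that are
        # still on the autograd graph, which is what DDP's backward
        # hooks need to trigger the reducer.
        return self.algorithm.unroll(unroll_length)
```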
Verified that both GPUs are being utilized, and that they are synchronized. The remaining problem is that when DDP is turned on, the time consumed by each training iteration increases significantly. On the same physical machine, the single-process non-DDP run takes around 130ms per iteration, while the dual-process DDP run takes 2.5 sec (2500ms) per iteration.
Preliminary investigation found that `backward()` (where the synchronization is supposed to happen) only increased from 11ms to 15ms per iteration, which is pretty insignificant. `unroll()` (which is wrapped by DDP) consumes about the same amount of time as well. Working on finding what explains the huge difference in time consumption now.
I was able to further hunt down the cause. The major contributor to the more-than-10x increase in time consumption comes from

```python
self.summarize_train(experience, train_info, loss_info, params)
```

In particular, I think this is because

```python
obs = alf.nest.find_field(experience, "observation")
```

has a different dimension. Comparing the dual-GPU DDP version and the single-GPU single-process version, the shape of `obs[0]` is respectively (I am using a 32-environment batch):

- dual: `(8, 32, 12, 210, 160)`
- single: `(8, 32, 4, 84, 84)`

Apparently some transformation was not applied in the dual-GPU version that is supposed to downsample the observation from `(12, 210, 160)` to `(4, 84, 84)`. Keep debugging.
To me this looks more like a bug when obtaining the input tensors. Usually we don't have a "downsampling" transformer in ALF; the env is directly responsible for resizing images. So probably you are using two different envs/wrappers. Also, the number of image channels is usually 3, or 3n with FrameStacker.
Thanks for the help, Haonan. I am slowly digging into that. Let me check the environment.
With some more debugging, I found that the problem is due to failing to apply `DMAtariPreprocessing`. In `atari_conf.py`, `suite_gym.load` is configured with:
```python
alf.config(
    'suite_gym.load',
    gym_env_wrappers=[gym_wrappers.DMAtariPreprocessing],
    # Default max episode steps for all games
    #
    # Per DQN paper setting, 18000 frames assuming frameskip = 4
    max_episode_steps=4500)
```
With the python debugger, I can see that:

- In the single-process single-GPU setting, both `DMAtariPreprocessing` and `4500` are correctly passed in.
- In the dual-process dual-GPU setting (DDP), neither of them is set.

So this is likely some configuration loading problem. I will need to read more on this to understand what's happening here.
And after a few hours of poking around and investigating, I finally found why the configuration is not respected. I'll summarize my discovery here. The process hierarchy is:

1. The main process starts a `training_worker()` for each GPU.
2. Each `training_worker()` starts `ProcessEnvironment`s in their own subprocesses.
3. Each `ProcessEnvironment` creates the environment in its own subprocess with `gym_suite.load()`.
Note that there are 2 hierarchies of subprocesses here. In order for a subprocess to inherit `_CONFIG_TREE` (or any other module-level variable) from its parent process, it needs to be started with `start_method = fork`, because `fork` (and seemingly only `fork`) copies the memory of allocated objects from the parent process.

In the single-GPU setup, this is not a problem because there is only 1 hierarchy of subprocesses, and the top-level process starts the environment processes with the default start method, which is `fork` on Linux.
However, in the dual-GPU/dual-process setup, the 2 `training_worker` processes have to be started with `spawn` in order for DDP to work. This also implicitly sets the default start method for their subprocesses to `spawn`. The environment processes (grandchild processes) started by `training_worker` therefore implicitly lose the inheritance of `_CONFIG_TREE`.
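A tiny standalone demonstration of the inheritance difference (nothing alf-specific; `_CONFIG_TREE` here is just a stand-in module-level dict):

```python
import multiprocessing as mp

_CONFIG_TREE = {}


def child(method):
    # Under 'fork' the child inherits the parent's memory, so the mutation
    # below is visible; under 'spawn' the module is re-imported from scratch.
    print(method, '->', _CONFIG_TREE)


if __name__ == '__main__':
    _CONFIG_TREE['key'] = 'value'
    for method in ('fork', 'spawn'):
        p = mp.get_context(method).Process(target=child, args=(method,))
        p.start()
        p.join()

# Prints: fork -> {'key': 'value'}
#         spawn -> {}
```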
The solution is simple once the above is figured out:

```python
ctx = multiprocessing.get_context('fork')
self._process = ctx.Process(...)
```
With the above problem fixed, the training for `ac_breakout` seems good now. I haven't started working on the shenanigans of checkpoint/summary/metrics, so the curves might look a bit messy, but the performance looks similar to that of a single GPU, within similar time. (Note that I am running 32 environments for each process in the dual-process dual-GPU setup.)
By turning on `profiling = True` in `TrainerConfig`, I have collected profiling metrics for the `ac_breakout` training.

Single process (non-DDP):

- `_train_iter_on_policy`: 0.283 per call
  - `unroll`: 0.216 per call
  - `train_from_unroll`: 0.065 per call

Dual process (DDP):

- `_train_iter_on_policy`: 0.220 per call
  - `_unroll`: 0.159 per call
  - `train_from_unroll`: 0.060 per call

So the DDP version is indeed faster, but not by a large margin. It uses about 25% less time to train the same number of iterations, which aligns with the numbers from above.
However, I discovered another problem: when DDP is on, even though the log files are generated, nothing is written to them. Will need to look at this as well.
On-policy algorithms can now enjoy DDP. The next step is to add full support for off-policy algorithms as well. Closing this issue in favor of #1096.