HorizonRobotics / alf

Agent Learning Framework https://alf.readthedocs.io
Apache License 2.0

Support Distributed Training in alf #913

Closed. breakds closed this issue 2 years ago.

breakds commented 3 years ago

As discussed with @emailweixu, it would be nice to have alf support multi-GPU training. The goals are

hnyu commented 3 years ago

@breakds FYI, two reference papers I came across a while ago (RL scenario):

https://openreview.net/pdf?id=H1gX8C4YPr https://ai.googleblog.com/2020/03/massively-scaling-reinforcement.html

Although they were proposed for multi-machine training, our single-machine multi-GPU case is a simpler special case.

Or refer to Pytorch official multi-gpu support (general DL scenario): https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html
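
For reference, a minimal sketch of how nn.DataParallel is typically used, with a generic toy model rather than ALF code:

import torch
import torch.nn as nn

# DataParallel splits the input batch across the visible GPUs, runs a replica
# of the module on each device, and gathers the outputs back on the default device.
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # uses all visible GPUs by default
model = model.cuda()

x = torch.randn(1024, 128, device='cuda')
y = model(x)  # scatter -> parallel forward -> gather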

breakds commented 3 years ago

Thanks @hnyu for the references!

breakds commented 3 years ago

The main idea of the first paper (DD-PPO) is to early-stop the slower simulations during rollout with a batched environment (potentially distributed over different machines in a cluster), and then, in each training iteration, use the full experience from some environments and partial experience from the early-stopped ones. I think we can borrow these ideas in the near future.

As a first step, I will look into how pytorch's DataParallel is implemented and use it (or a similar technique) to enable, on a single multi-GPU machine:

  1. Network(s) forward evaluation during rollout
  2. Network(s) forward/backward evaluation during training

in each training iteration.

breakds commented 3 years ago

Currently I am hitting two problems with DataParallel:

  1. When tested on a simple backward() operation, the DataParallel version (2 GPUs) takes 1.8 seconds while the single-GPU version takes only 0.5 seconds, and I haven't figured out why. I think this at least suggests that the overhead of DataParallel is significant (a minimal timing sketch follows below).
  2. Directly applying DataParallel to our network crashes. This is the other issue I am working on.
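
As a reference point, one way to time such a comparison (a generic sketch, not the exact script used here; CUDA kernels run asynchronously, so synchronizing before reading the clock matters):

import time
import torch
import torch.nn as nn

def time_forward_backward(model, x, target, iters=100):
    loss_fn = nn.MSELoss()
    # Warm up first so one-time CUDA initialization is not counted.
    for _ in range(5):
        loss_fn(model(x), target).backward()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        loss_fn(model(x), target).backward()
    # Wait for all queued kernels before reading the clock.
    torch.cuda.synchronize()
    return time.time() - start

# Example: compare a plain module against its DataParallel wrapper.
net = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 1)).cuda()
x = torch.randn(25600, 512, device='cuda')
target = torch.randn(25600, 1, device='cuda')
print('single GPU  :', time_forward_backward(net, x, target))
print('DataParallel:', time_forward_backward(nn.DataParallel(net), x, target))
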
hnyu commented 3 years ago
  1. When tested on a simple backward() operation, the DataParallel version (2 GPUs) takes 1.8 seconds while the single-GPU version takes only 0.5 seconds, and I haven't figured out why. I think this at least suggests that the overhead of DataParallel is significant.

I think multi-gpu only makes sense for a large mini-batch with intensive computation. What is your setup?

breakds commented 3 years ago
  1. When tested on a simple backward() operation, the DataParallel version (2 GPUs) takes 1.8 seconds while the single-GPU version takes only 0.5 seconds, and I haven't figured out why. I think this at least suggests that the overhead of DataParallel is significant.

I think multi-gpu only makes sense for a large mini-batch with intensive computation. What is your setup?

Yep I think that is what happened. I was testing the forward and backward of a network like this:

import torch.nn as nn

class Network(nn.Module):
    def __init__(self, input_size, output_size):
        super(Network, self).__init__()
        self.fc1 = nn.Linear(input_size, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 384)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(384, 64)
        self.relu3 = nn.ReLU()
        self.fc4 = nn.Linear(64, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        h = self.fc1(x)
        h = self.relu1(h)
        h = self.fc2(h)
        h = self.relu2(h)
        h = self.fc3(h)
        h = self.relu3(h)
        h = self.fc4(h)
        h = self.sigmoid(h)
        return h

I realized that this network is probably too small, because even with a batch of 25600 the elapsed time pretty much does not change.

I am now trying to fix issue No. 2 so that I can experiment with an actual network used in alf.

hnyu commented 3 years ago
I realized that this network is probably too small, because even with a batch of 25600 the elapsed time pretty much does not change. I am now trying to fix issue No. 2 so that I can experiment with an actual network used in alf.

Our expected scenario for multi-GPU is image inputs with a large batch size, so you could try dummy image inputs instead.

Besides running time, another scenario is splitting SGD memory consumption across multiple cards when one card is not enough.

breakds commented 3 years ago
Our expected scenario for multi-GPU is image inputs with a large batch size, so you could try dummy image inputs instead. Besides running time, another scenario is splitting SGD memory consumption across multiple cards when one card is not enough.

That makes a lot of sense. Thanks for the suggestions and clarification!

breakds commented 3 years ago

I was using ActorDistributionNetwork with a batch of randomly generated images to run the experiment, and got:

Traceback (most recent call last):
  File "/nix/store/4s0h5aawbap3xhldxhcijvl26751qrjr-python3-3.8.9/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/nix/store/4s0h5aawbap3xhldxhcijvl26751qrjr-python3-3.8.9/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/breakds/projects/alf/alf/bin/experiment/dp_network_experiment.py", line 42, in <module>
    action_distribution, actor_state = actor_network(observation, state=())
  File "/nix/store/1nhxgafz45v9sivabxw0aqr0dvpyw1nc-python3-3.8.9-env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/nix/store/1nhxgafz45v9sivabxw0aqr0dvpyw1nc-python3-3.8.9-env/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
    return self.gather(outputs, self.output_device)
  File "/nix/store/1nhxgafz45v9sivabxw0aqr0dvpyw1nc-python3-3.8.9-env/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 180, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/nix/store/1nhxgafz45v9sivabxw0aqr0dvpyw1nc-python3-3.8.9-env/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 76, in gather
    res = gather_map(outputs)
  File "/nix/store/1nhxgafz45v9sivabxw0aqr0dvpyw1nc-python3-3.8.9-env/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 71, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
  File "/nix/store/1nhxgafz45v9sivabxw0aqr0dvpyw1nc-python3-3.8.9-env/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 71, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
TypeError: 'Categorical' object is not iterable

With some investigation, I realized that it fails because DataParallel internally does not know how to combine Categorical objects, which are the output of ActorDistributionNetwork. DataParallel works in two steps:

  1. Scatter (think of map)
  2. Gather (think of reduce)

This problem happens at the last step, "Gather". I will use a slightly modified network to continue the experiment and work around this issue.

However, the final solution should make multi-GPU as transparent as possible so that it is convenient to use.

Directly applying DataParallel may not be the right solution for us, partly because of the above issue. This is something to think about later.
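
One possible workaround for the "Gather" step (not necessarily the route taken here) is to keep distribution objects outside the DataParallel boundary: have the wrapped module return plain tensors such as logits, which gather knows how to combine, and construct the Categorical afterwards. A self-contained toy sketch:

import torch
import torch.nn as nn
import torch.distributions as td

class LogitsNet(nn.Module):
    # Toy stand-in: returns raw logits (a plain tensor) instead of a Categorical.
    def __init__(self, input_size, num_actions):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(input_size, 64), nn.ReLU(),
                                  nn.Linear(64, num_actions))

    def forward(self, x):
        return self.body(x)

net = nn.DataParallel(LogitsNet(16, 4)).cuda()
obs = torch.randn(256, 16, device='cuda')
logits = net(obs)                      # gather succeeds: the output is a tensor
dist = td.Categorical(logits=logits)   # distribution built outside DataParallel
action = dist.sample()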

breakds commented 3 years ago

After slightly modifying the ActorDistributionNetwork (for experimentation purposes), I was able to run DataParallel with 2 x 3080:

import torch
import torch.nn as nn
import alf
from alf.networks import ActorDistributionNetwork
from alf.tensor_specs import BoundedTensorSpec
import functools
import time

if __name__ == '__main__':
    alf.set_default_device('cuda')

    CONV_LAYER_PARAMS = ((32, 8, 4), (64, 4, 2), (64, 3, 1))

    actor_network_cls = functools.partial(
        ActorDistributionNetwork,
        fc_layer_params=(512, ),
        conv_layer_params=CONV_LAYER_PARAMS)

    actor_network = nn.DataParallel(actor_network_cls(
        input_tensor_spec=BoundedTensorSpec(
            shape=(4, 150, 150), dtype=torch.float32, minimum=0., maximum=1.),
        action_spec=BoundedTensorSpec(
            shape=(), dtype=torch.int64, minimum=0, maximum=3)))

    start_time = time.time()
    for i in range(1000):
        observation = torch.rand(640, 4, 150, 150)
        action_distribution, actor_state = actor_network(observation, state=())
    print(f'{time.time() - start_time} seconds elapsed')

I can see the load being distributed to the 2 cards (and the memory as well). However, compared to running the same piece of code on a single 3080 without DataParallel:

  1. The total memory consumption across both cards is significantly higher than the single-card, non-DataParallel version.
  2. The non-DataParallel version took 6 seconds to finish on a single card; the DataParallel version took 1 minute.

This almost renders DataParallel unusable. I will continue investigating why this odd behavior exists, and will discuss it with people who have more experience tomorrow.

hnyu commented 3 years ago

The non-DataParallel version took 6 seconds to finish on a single card; the DataParallel version took 1 minute.

The inefficiency of DataParallel seems unreasonable. There must be something wrong going on.

breakds commented 3 years ago

The inefficiency of DataParallel seems unreasonable. There must be something wrong going on.

Or maybe this is by design; I will try to look into where the time is being spent.

emailweixu commented 3 years ago

According to https://pytorch.org/tutorials/intermediate/ddp_tutorial.html, DataParallel might be even slower than DistributedDataParallel

breakds commented 3 years ago

According to https://pytorch.org/tutorials/intermediate/ddp_tutorial.html, DataParallel might be even slower than DistributedDataParallel

Yep, the GIL issue mentioned there makes sense. DistributedDataParallel is even harder to integrate; if we are willing to spend more effort, it would probably be better to roll our own solution that suits us better.

@hnyu and I chatted about this today, and I agree with Haonan that we might want to adjust our goal and go for a slightly more involved (i.e., possibly requiring structural updates) customized solution. We can chat more about this tomorrow.
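
For reference, the single-machine multi-GPU skeleton from the linked DDP tutorial looks roughly like this (a generic sketch, not ALF's eventual integration):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def worker(rank, world_size):
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model = nn.Linear(32, 4).to(rank)
    ddp_model = DDP(model, device_ids=[rank])
    opt = torch.optim.SGD(ddp_model.parameters(), lr=1e-3)

    x = torch.randn(64, 32, device=rank)
    loss = ddp_model(x).sum()
    loss.backward()      # gradients are all-reduced across processes here
    opt.step()
    dist.destroy_process_group()

if __name__ == '__main__':
    mp.spawn(worker, args=(2,), nprocs=2)  # one process per GPU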

breakds commented 3 years ago
  1. Successfully ran DDP on the 2-GPU machine with ActorDistributionNetwork. Preliminary results show about a 25% performance improvement over the non-parallel version (this is only a single data point, from that 2-GPU machine).
  2. Experimented with running DDP in alf. There are definitely a lot of caveats. I am very close to getting ac_breakout running, but at this moment I need to resolve #930 first.
breakds commented 3 years ago

After working around the above issue, I ran into another blocker with DistributedDataParallel.

On the surface, the problem is that the DDP wrapper gets stuck. After debugging with pudb, it turns out that the underlying reducer is throwing an exception:

(screenshot of the exception)

All online discussions about this lead to this pytorch issue, which seems to be a well-known one. To summarize, the current implementation of DDP assumes that all parameters in a module are used. If there are unused parameters, the reducer fails to communicate their gradients to the other processes.

breakds commented 3 years ago

Documenting possible next steps (I am a bit reluctant to commit to any path yet):

  1. Rely on the third-party package apex. It is said that apex will eventually be merged into pytorch, but likely not sooner than the next pytorch release.
  2. Or, bring apex in as a dependency now. My fear is that by bringing in another black box (which is much less well maintained than pytorch itself) we will get more problems than features.
  3. Or, try to hunt down all the unused parameters and get rid of them. I am afraid most unused parameters are introduced for a reason, and in some cases they are only "unused" under certain configurations. Besides, forcing out unused parameters seems too restrictive and unfriendly to developers (see also the note after this list).
  4. Or, roll out our own version of distributed training, probably based on nccl. We would have to pay the cost of a major design effort and refactor; whether it is worth it depends entirely on the benefit.
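
A note on option 3: DDP itself exposes a find_unused_parameters flag that makes the reducer tolerate parameters receiving no gradient in a given iteration, at the cost of traversing the autograd graph every backward(). Whether it applies to this particular failure is a separate question. A one-line sketch (model and rank as in the earlier skeleton):

from torch.nn.parallel import DistributedDataParallel as DDP

# Tolerates parameters that get no gradient in some iterations; adds a
# per-iteration graph traversal as overhead.
ddp_model = DDP(model, device_ids=[rank], find_unused_parameters=True)
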
emailweixu commented 3 years ago

Which algorithm are you trying to run with DDP? For some algorithms (e.g. DDPG, SAC), there are parameters which are not handled by an optimizer and hence have no gradient. For ActorCriticAlgorithm, there shouldn't be such parameters. Algorithm.get_unoptimized_parameter_info() can be used to get all the parameters which are not handled by optimizers.

breakds commented 3 years ago

Thanks @emailweixu !

After discussion with Wei, I realized that the conclusion about "unused parameters" was rather premature. The "incompatible function argument" exception happens when a C/C++ function is called from python but the passed argument cannot be converted.

In fact, in the above exception the function at fault is broadcast_coalesced. The problem happens because we passed, as the 2nd argument, a list that cannot be converted to a TensorList.

Further tracing the program with the debugger confirmed this:

(screenshot of the debugger session)

Normally an optimizer is not part of the module, so it does not get into the module's state_dict. In our case, however, the optimizer is in the state_dict, so it gets into the list to be broadcast. Since it is not a Tensor, it breaks the type enforcement in the C++ function.

Knowing that, the solution could be:

  1. The cleaner way might be to get the optimizers out of the state_dict. I need to confirm whether the optimizer state needs to be in the checkpoint (if so, we probably want it in the state_dict). My preliminary conclusion is that we do need it in the state_dict, because the optimizer state matters if we resume training from a checkpoint.
  2. The other way is to use self.parameters_to_ignore in DDP. I just need to figure out a way to add the optimizers to it.
emailweixu commented 3 years ago

One way to solve the problem is to restore the semantics of Module.state_dict so that it does not include optimizer state, and add another pair of functions to Algorithm (e.g. optimizer_state_dict, load_optimizer_state_dict) to handle the optimizer state. checkpoint_util also needs to be changed accordingly. cc @Haichao-Zhang since he wrote this part originally.

breakds commented 3 years ago

Thanks @emailweixu ! I was able to bypass the problem by adding the optimizers to the ignore list (there is a hidden API in DDP for this). Just to be safe, I will add a check before applying the DDP wrapper to make sure that everything in the state_dict is a Tensor, and if not, add the offending entries to the ignore list (a sketch follows below).
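
For reference, a sketch of the check described, assuming the hidden API in question is DDP's private _set_params_and_buffers_to_ignore_for_model (names such as algorithm and rank are placeholders):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def non_tensor_state_keys(module):
    # Entries in state_dict that are not tensors (e.g. the optimizer state that
    # ALF currently stores inside the module's state_dict) would break DDP's
    # initial broadcast, so they need to be ignored.
    return [name for name, value in module.state_dict().items()
            if not isinstance(value, torch.Tensor)]

# 'algorithm' is the module being wrapped and 'rank' the local GPU index.
# The ignore API is private and may change across pytorch versions.
to_ignore = non_tensor_state_keys(algorithm)
if to_ignore:
    DDP._set_params_and_buffers_to_ignore_for_model(algorithm, to_ignore)
ddp_algorithm = DDP(algorithm, device_ids=[rank])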

Haichao-Zhang commented 3 years ago

Seems that you have figured out a solution. Let me know if you need any further discussions on this.

breakds commented 3 years ago

After some more investigation and messing around, I was able to run alf training with DDP. There are a lot of places where we need to get the device number right. Currently I can observe that both GPUs are being used (in terms of memory), but only one of them is actually running (the other seems idle to me). A snapshot of nvidia-smi shows:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.73.01    Driver Version: 460.73.01    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  GeForce RTX 3080    Off  | 00000000:04:00.0 Off |                  N/A |
|  0%   56C    P2   107W / 340W |   3558MiB / 10015MiB |     27%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 3080    Off  | 00000000:09:00.0 Off |                  N/A |
|  0%   45C    P8    15W / 340W |   1755MiB / 10018MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      1610      G   ...xorg-server-1.20.11/bin/X        9MiB |
|    0   N/A  N/A      1687      G   ...hell-40.1/bin/gnome-shell        8MiB |
|    0   N/A  N/A   1794885      C   ...thon3-3.8.9/bin/python3.8     2217MiB |
|    0   N/A  N/A   1794886      C   ...thon3-3.8.9/bin/python3.8     1319MiB |
|    1   N/A  N/A      1610      G   ...xorg-server-1.20.11/bin/X        4MiB |
|    1   N/A  N/A   1794886      C   ...thon3-3.8.9/bin/python3.8     1747MiB |
+-----------------------------------------------------------------------------+

From the logs I can clearly see that only one of the 2 processes is running and producing output. This is weird because, based on my understanding of how DDP works, one process has to wait for the other at backward() before it can go on. Will keep debugging along this direction.

breakds commented 3 years ago

Turns out my hypothesis that gpu 0 and gpu 1 run sequentially was wrong. I made 2 mistakes in my toy example with ActorDistributionNetwork:

  1. I did not call backward(), so the 2 processes were never synchronized except at initialization.
  2. I did not run enough iterations; it turns out that device 1 takes a few seconds to warm up (not sure why).

The good news is that, after the fix, the toy example

    start_time = time.time()
    for i in range(2500):
        observation = torch.rand(batch_size, 4, 84, 84, device=rank)
        action_distribution, actor_state = actor_network(observation, state=())
        action = action_distribution.sample()
        reward = torch.rand(batch_size, device=rank)
        loss = - torch.mean(action_distribution.log_prob(action) * reward)
        loss.backward()
        if i % 100 == 0:
            print(f'iteration {i} - {time.time() - start_time} seconds elapsed on device {rank}')
    print(f'{time.time() - start_time} seconds elapsed on device {rank}')

shows that when batch_size is big enough (in my comparison, the batch size is divided by the number of GPUs for the multi-GPU run), multi-GPU has some gain in terms of elapsed time:

Single GPU: 31.069265604019165 seconds elapsed on device 0

Double GPU:
24.037397384643555 seconds elapsed on device 1
24.038329362869263 seconds elapsed on device 0

for a batch size of 1024. The bad news is that this does not explain why running alf.bin.train leads to only one process progressing. Keep investigating.

hnyu commented 3 years ago

I think for this toy example, you can also try more complex CNN architectures like (ResnetEncodingNetwork) to see the gain.

breakds commented 3 years ago

Turns out the problem is still that one of the two processes throws an exception, but the exception is not observed.

About "the exception is not observed"

I actually researched (and experimented with) how exceptions in nested subprocesses work yesterday, and thought I had solved this problem. Sadly, there are still cases where such an exception is raised silently. Normally I would expect it to show up in the terminal, because I explicitly catch it and print it in the offending process.

About the exception itself

The exception itself looks like this:

File "/home/breakds/projects/alf/alf/utils/tensor_utils.py", line 87, in tensor_extend_zero
    return torch.cat((x, torch.zeros(1, *x.shape[1:], dtype=x.dtype)))
RuntimeError: All input tensors must be on the same device. Received cuda:1 and cuda:0

With DDP, we would like newly created tensors to be placed on a process-dependent default device (a.k.a. rank/device id). I have patched quite a few places with to(rank), but it turns out there are too many occurrences to cover (and I want to get this working before committing to making them clean and nice). I will have to go back and research the process-dependent default device.

Update

torch.cuda.set_device(rank) seems to do the trick. But the behavior is still that only device 0's process progresses. Keep investigating.
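
In other words, each worker sets its own default CUDA device early on; a minimal sketch (training_worker and its arguments are hypothetical here):

import torch

def training_worker(rank, world_size):
    # Make this process's GPU the current CUDA device so that tensors created
    # with device='cuda' (no explicit index), as ALF does internally, land on
    # cuda:<rank> instead of cuda:0.
    torch.cuda.set_device(rank)
    assert torch.zeros(1, device='cuda').device.index == rank
    # ... DDP setup and training would follow here ...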

breakds commented 3 years ago

I think for this toy example, you can also try more complex CNN architectures like (ResnetEncodingNetwork) to see the gain.

Acknowledged. Will try it later. Thanks!

breakds commented 3 years ago

More updates: with some other small problems fixed, I am now able to train with 2 GPUs under the DDP wrapper:

INFO:absl:[rank=0] None -> ac_breakout: 79 time=2.391 throughput=107.07
INFO:absl:[rank=1] None -> ac_breakout: 79 time=2.346 throughput=109.14
INFO:absl:[rank=1] None -> ac_breakout: 85 time=0.185 throughput=1385.38
INFO:absl:[rank=0] None -> ac_breakout: 85 time=0.191 throughput=1341.97
INFO:absl:[rank=0] None -> ac_breakout: 89 time=2.386 throughput=107.31
INFO:absl:[rank=1] None -> ac_breakout: 89 time=2.463 throughput=103.93
INFO:absl:[rank=1] None -> ac_breakout: 95 time=0.177 throughput=1444.35
INFO:absl:[rank=0] None -> ac_breakout: 95 time=0.178 throughput=1436.04
INFO:absl:[rank=0] None -> ac_breakout: 99 time=2.333 throughput=109.75
INFO:absl:[rank=1] None -> ac_breakout: 99 time=2.446 throughput=104.67
INFO:absl:[rank=0] None -> ac_breakout: 105 time=0.165 throughput=1550.85
INFO:absl:[rank=1] None -> ac_breakout: 105 time=0.169 throughput=1518.86
INFO:absl:[rank=0] None -> ac_breakout: 109 time=2.363 throughput=108.32
INFO:absl:[rank=1] None -> ac_breakout: 109 time=2.405 throughput=106.46
INFO:absl:[rank=0] None -> ac_breakout: 115 time=0.177 throughput=1446.16
INFO:absl:[rank=1] None -> ac_breakout: 115 time=0.176 throughput=1453.41

The synchronization is there too. The trained result cannot be played yet (which is expected); I will take a closer look at the checkpoints. Meanwhile, I will start thinking about a cleaner implementation.

breakds commented 3 years ago

Outline of the plan for next steps, after a discussion on 2021/07/23:

TODO Productionize DDP over ALF [3/11]

  1. [X] Update =train.py=

  2. [x] How does DDP guarantee that all optimizer.step() calls are synchronized?

    • Per the documentation of DDP: parameters are never broadcast between processes; the module performs an all-reduce step on gradients and assumes that they will be modified by the optimizer in all processes in the same way.
    • This means that we might want to broadcast parameters after optimizer.step() in one process, instead of having to run optimizer.step() in multiple processes doing (supposedly) exactly the same work.
  3. [x] Figure out how DDP establishes backward callbacks. This proves whether wrapping the algorithm is enough.

  4. [x] Only rank 0 process writes checkpoints (see the sketch after this list)

  5. [x] Can we still run =play.py=?

  6. [ ] Does replay buffer synchronize over DDP?

    • Does this affect prioritized replay buffer?
  7. [ ] Tensorboard: any metrics changed their meaning?

  8. [x] How is batchnorm affected? Do we have to use the ~SyncBatchNorm~ as mentioned in the source code of DDP?

  9. [x] Debugging and =Ctrl-C= handling

  10. [x] Figuring out why starting environments takes longer in DDP mode

  11. [ ] Figuring out why DDP makes each iteration slower
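
Regarding item 4, the usual pattern (a sketch assuming the process group is already initialized, not necessarily ALF's eventual implementation) is to gate the write on the process rank:

import torch
import torch.distributed as dist

def save_checkpoint(ddp_model, path):
    # DDP keeps the replicas' parameters in sync, so one copy is enough;
    # writing from every rank would just race on the same file.
    if dist.get_rank() == 0:
        torch.save(ddp_model.module.state_dict(), path)  # .module unwraps DDP
    dist.barrier()  # keep the other ranks roughly in step around the save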

breakds commented 3 years ago

Update

I found that after training for 10 minutes, the two processes start to behave as if they are not synchronized. I think I have some misunderstanding of how DDP works.

breakds commented 3 years ago

Had a discussion with @emailweixu while reading the DDP code, and we figured out why the above approach (directly wrapping Algorithm) does not work.

How DDP works

The steps below demonstrate how DDP works in one iteration, assuming m is the original module and

w = DDP(m)

is m with the DDP wrapper.

  1. w hijacks m's forward().
  2. When output = w.forward() is called, it calls m.forward() under the hood.
  3. However, w.forward() does something extra: it registers a _DDPSink on the returned output.
  4. Later, when output.backward() is called, _DDPSink's backward callback is invoked.
  5. That backward callback in turn injects a "reducer callback" at the end of the current computation graph's computation.
  6. Therefore, when the whole backward() is done, the "reducer callback" is invoked and does the synchronization.

This explains why wrapping Algorithm won't work: we are not calling Algorithm's forward(). In fact, it does not even have a backward().

The next idea to try is to wrap whatever produces train_info, making train_info part of the output of a forward(). This tricks DDP into doing what we want.

breakds commented 3 years ago

Now, with a DDP wrapper applied to the unroll() of RLAlgorithm, training runs successfully on 2 GPUs:

class RLAlgorithm:
    def activate_ddp(self, rank: int):
        self.__dict__['_unroll_performer'] = DDP(UnrollPerformer(self), device_ids=[rank])

Note that UnrollPerformer is itself an nn.Module whose forward() wraps unroll() (a rough sketch of it follows below).

Verified that both GPUs are being utilized, and that they are synchronized:

  1. Corresponding iterations finish at almost the same time in both processes.
  2. If one process is manually paused, the other pauses at the same training iteration.
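
UnrollPerformer itself is not shown in this thread; a rough sketch of the idea (names and signatures are illustrative, not the actual ALF code):

import torch.nn as nn

class UnrollPerformer(nn.Module):
    """Hypothetical sketch: expose an algorithm's unroll() as forward().

    DDP only instruments forward(); routing unroll() through forward() lets
    DDP attach its hooks to the tensors inside the returned train_info, so
    the gradients computed from it later get all-reduced as usual.
    """

    def __init__(self, algorithm):
        super().__init__()
        # Registering the algorithm as a submodule lets DDP see its parameters.
        self.inner_algorithm = algorithm

    def forward(self, unroll_length):
        # Signature is illustrative; the real unroll() may take different arguments.
        return self.inner_algorithm.unroll(unroll_length)
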
breakds commented 3 years ago

The remaining problem is that turning on DDP significantly increases the time consumed by each training iteration. On the same physical machine, single-process non-DDP training takes around 130 ms per iteration, while dual-process DDP takes about 2.5 s (2500 ms) per iteration.

Preliminary investigation found that backward() (where the synchronization is supposed to happen) only increased from 11 ms to 15 ms per iteration, which is insignificant. unroll() (which is wrapped by DDP) consumes about the same amount of time as well. Now working on finding what explains the huge difference.

breakds commented 3 years ago

I was able to hunt down the cause further. The major contributor to the more-than-10x increase in time per iteration comes from

        self.summarize_train(experience, train_info, loss_info, params)

In particular, I think this is because

            obs = alf.nest.find_field(experience, "observation")

has a different shape. Comparing the dual-GPU DDP version and the single-GPU single-process version, the shape of obs[0] is respectively (I am using a 32-environment batch):

dual: (8, 32, 12, 210, 160) single: (8, 32, 4, 84, 84)

Apparently some of the transformations, which are supposed to downsample the observation from (12, 210, 160) to (4, 84, 84), were not applied in the dual-GPU version. Keep debugging.

hnyu commented 3 years ago

dual: (8, 32, 12, 210, 160) single: (8, 32, 4, 84, 84)

To me this looks more like a bug in obtaining the input tensors. Usually we don't have a "downsampling" transformer in ALF; the env is directly responsible for resizing images. So you are probably using two different envs/wrappers. Also, the number of image channels is usually 3, or 3n with FrameStacker.

breakds commented 3 years ago

Thanks for the help, Haonan. I am slowly digging into that. Let me check the environment.

breakds commented 3 years ago

With some more debugging, I found that the problem is due to "failing to apply DMAtariPreprocessing".

In atari_conf.py, suite_gym.load is configured with

alf.config(
    'suite_gym.load',
    gym_env_wrappers=[gym_wrappers.DMAtariPreprocessing],
    # Default max episode steps for all games
    #
    # Per DQN paper setting, 18000 frames assuming frameskip = 4
    max_episode_steps=4500)

With the python debugger, I can see that:

  1. In single process single GPU setting, both DMAtariPreprocessing and 4500 are correctly passed in
  2. In dual process dual GPU setting (DDP), neither of them is set

So this is likely some configuration loading problem. I will need to read more on this to understand what's happening here.

breakds commented 3 years ago

After a few hours of poking around and investigating, I finally found out why the configuration is not respected. I'll summarize my discovery here:

  1. The main process starts a training_worker() process for each GPU.
  2. Each training worker then starts a bunch of ProcessEnvironment instances in their own subprocesses.
  3. Each ProcessEnvironment creates its environment in its own subprocess with suite_gym.load().

Note that there are 2 levels of subprocesses here. In order for a subprocess to inherit _CONFIG_TREE (or any other module-level variable) from its parent process, it needs to be started with start_method = fork, because fork (and seemingly only fork) copies the parent process's memory, including already-allocated objects.

In the single-GPU setup, this is not a problem because there is only 1 level of subprocesses, and the top-level process starts the environment processes with the default start method, which is fork on Linux.

However, in the dual-GPU/dual-process setup, the 2 training_worker processes have to be started with spawn for DDP to work. This also implicitly sets the default start_method for their subprocesses to spawn. The environment processes (grandchild processes) started by training_worker therefore lose the inheritance of _CONFIG_TREE.

The solution is simple once the above is figured out: explicitly use a fork context when starting the environment processes:

        ctx = multiprocessing.get_context('fork')
        self._process = ctx.Process(...)
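
In other words, roughly (a simplified sketch of the fix, not the actual ProcessEnvironment code; run_environment is a placeholder):

import multiprocessing

import torch.multiprocessing as torch_mp

def run_environment():
    pass  # placeholder for the per-environment simulation loop

def training_worker(rank, world_size):
    # ... DDP setup for this rank would go here ...
    # Environment subprocesses are forked explicitly so that they inherit
    # module-level state such as alf's _CONFIG_TREE from this (already
    # configured) worker, even though the worker itself was spawned.
    ctx = multiprocessing.get_context('fork')
    env_process = ctx.Process(target=run_environment)
    env_process.start()
    env_process.join()

if __name__ == '__main__':
    # The per-GPU training workers are started with 'spawn', which is needed
    # here for DDP to work (per the discussion above).
    torch_mp.spawn(training_worker, args=(2,), nprocs=2)
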
breakds commented 3 years ago

With the above problem fixed, the training for ac_breakout seems good now.

(training curve screenshot: train_ddp)

I haven't started working on the shenanigans of checkpoint/summary/metrics, so the curves might look a bit messy, but it looks similar to the performance of a single GPU, within a similar amount of time. (Note that I am running 32 environments for each process in the dual-process, dual-GPU setup.)

breakds commented 3 years ago

By turning on profiling = True in TrainerConfig, I have collected the profiling metrics for the ac_breakout training:

  1. Single Process
    • _train_iter_on_policy: 0.283 per call
      • unroll: 0.216 per call
      • train_from_unroll: 0.065 per call
  2. Dual Process DDP (Measuring the master process)
    • _train_iter_on_policy: 0.220 per call
      • _unroll: 0.159 per call
      • train_from_unroll: 0.060 per call

So the DDP version is indeed faster, but not by a large margin. It uses about 25% less time to train the same number of iterations, which aligns with the numbers from above.

breakds commented 3 years ago

However, I discovered another problem - when DDP is on, even though the log files are generated, nothing is written to them. Will need to look at this as well.

breakds commented 2 years ago

On-policy algorithms can now enjoy DDP. The next step is to add full support for off-policy algorithms as well. Closing this issue in favor of #1096.