vacancy / Synchronized-BatchNorm-PyTorch

Synchronized Batch Normalization implementation in PyTorch.
MIT License

Weird things happen when applied to FPN #3

Open ShuLiu1993 opened 6 years ago

ShuLiu1993 commented 6 years ago

Hi, thanks a lot for your code.

But when I apply this code to my end-to-end implementation of FPN, some weird things happen.

If I use 8 cards, GPU memory usage keeps increasing until one card's memory is completely full. Then FPN gets stuck, and the GPU utilization of the other (not-full) cards is 100%.

If I use 4 cards, memory usage keeps increasing and then FPN may get stuck with 0% utilization on all GPUs.

I use PyTorch 0.4.0 with your DataParallelWithCallback, and the input image size differs across cards. If I use the official PyTorch BN instead, my code works well.

Could you please give me any hints to help me find the cause?

vacancy commented 6 years ago

Hi @ShuLiu1993 thank you for your report.

My guess is that there might be multiple reasons for this.

  1. This code was tested with PyTorch 0.3.1. Since 0.4 is still a very new release with some changes that break backward compatibility, I am not sure whether this code works perfectly with 0.4. Can you temporarily switch your PyTorch version to 0.3.1?
  2. This code has been tested in real-world applications such as semantic segmentation (see https://github.com/CSAILVision/semantic-segmentation-pytorch). I am not sure about your own implementation of FPN. Is it possible that something in your code is not compatible with SyncBN?

If possible, could you please provide a minimal example that reproduces the behavior you encountered? It would be extremely helpful for finding potential compatibility issues or bugs.

ShuLiu1993 commented 6 years ago

Hi @vacancy,

Thanks a lot for your reply.

I solved the "memory leaking issue" by cleaning the environment and using the code from the repo (https://github.com/CSAILVision/semantic-segmentation-pytorch), although the two versions seem very similar.

However, training may still get stuck after several iterations. When it gets stuck, GPU utilization is 0% and CPU utilization is also fairly low. I still cannot find the reason.

The environment I use now is PyTorch 0.4.0 + Python 3.6, the same as the one used in the repo (https://github.com/CSAILVision/semantic-segmentation-pytorch).

vacancy commented 6 years ago

Hi @ShuLiu1993

The semantic-segmentation repo actually does not use the official 0.4 release. Their code was based on a pre-release (GitHub) version of PyTorch. Could you please also try 0.3.1? You can install it easily with conda, using a command such as conda install pytorch=0.3.1.

I am sorry that I cannot reproduce the behavior you describe, since I am not working with your exact model. If it is convenient, could you please provide minimal reproduction code? I would also appreciate it if you could help me debug the issue.

If you can confirm that SyncBatchNorm is what causes the hang, you can try inserting a try-except block in SyncBatchNorm's forward function. This lets you inspect any exceptions more easily.

ShuLiu1993 commented 6 years ago

Hi @vacancy

Thanks for the reply.

I tried PyTorch 0.3.1 installed with conda, but it always gives me "Assertion pos >= 0 && pos < buffer.size() failed.". My labmates have already successfully reproduced the semantic-segmentation repo with official PyTorch 0.4.0.

I am trying to localize the issue, and I may share minimal reproduction code once I have. Do you have any suggestions that might help with debugging?

vacancy commented 6 years ago

@ShuLiu1993 I see. The error you get with PyTorch 0.3.1 is caused by a bug reported in this issue.

Thank you so much for the help. For debugging, could you first check whether any error is raised in forward? Because of how Python handles threads, exceptions raised in child threads may not be propagated to the main thread. I recommend wrapping the forward function manually in a try-except block to check whether any error occurred.

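Here is some sample code:

import logging

logger = logging.getLogger(__name__)

def forward(self, input):
    try:
        return self._forward(input)
    except Exception:
        # Log the full traceback, then re-raise so the error is not silently swallowed.
        logger.exception('Error occurred.')
        raise

def _forward(self, input):
    # The original forward code goes here.
    ...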

You can also rewrite this as a Pythonic method decorator.
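A minimal sketch of such a decorator, assuming you only want to log the traceback and then re-raise. The names log_exceptions and ToyLayer are hypothetical; ToyLayer is a plain-Python stand-in for a module, and the same decorator could be placed on a SyncBatchNorm forward method.

```python
import functools
import logging

logger = logging.getLogger(__name__)


def log_exceptions(method):
    """Hypothetical decorator: log the full traceback of any exception, then re-raise it."""
    @functools.wraps(method)  # keep the original name/docstring on the wrapper
    def wrapper(*args, **kwargs):
        try:
            return method(*args, **kwargs)
        except Exception:
            # logger.exception logs at ERROR level and appends the current traceback.
            logger.exception('Error occurred in %s.', method.__qualname__)
            raise  # re-raise so the failure is not silently swallowed
    return wrapper


class ToyLayer:
    # Hypothetical stand-in for a module whose forward might fail.
    @log_exceptions
    def forward(self, x):
        return 1.0 / x


layer = ToyLayer()
print(layer.forward(2.0))  # 0.5
```

Calling layer.forward(0.0) logs the ZeroDivisionError traceback and then re-raises it. Re-raising keeps the decorator transparent: the wrapped method still fails the same way, but the error is recorded even if a caller further up swallows it.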

leafxx commented 6 years ago

@ShuLiu1993 Hi, have you solved this problem? My training also gets stuck after several iterations, and all dataloader threads are dead. Is this problem related to RoI pooling implemented in PyTorch? Do you use RoI pooling in your code?

chauhochow commented 6 years ago

My training also gets stuck after several iterations. My PyTorch version is 0.4.1. I am just trying to use resnet101 for the segmentation task.

vacancy commented 6 years ago

@chauhochow Sorry to bother you, but could anyone provide sample code (ideally minimal) that reproduces the hang?

I am not using this module at the moment, so sample code from you would be extremely helpful for identifying the cause. Thanks in advance!

vacancy commented 6 years ago

Hi @ShuLiu1993 @leafxx @chauhochow,

While debugging someone else's code, I realized that your code getting stuck can probably be attributed to control flow in your implementation. Recall one of the key notes I wrote in the README file:

The implementation requires that each module on different devices invokes the batchnorm exactly the SAME number of times in each forward pass. For example, you cannot call batchnorm on GPU0 but not on GPU1. The #i-th (i = 1, 2, 3, ...) calls of the batchnorm on each device will be viewed as a whole and the statistics will be reduced. This is tricky but is a good way to handle PyTorch's dynamic computation graph. Although it sounds complicated, this will usually not be an issue for most models.
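To make "the statistics will be reduced" concrete, here is a minimal sketch in plain Python, assuming the reduction combines each device's sum, sum of squares, and element count (the usual way synchronized BN computes global statistics). The function names are hypothetical and this is not this repository's code. It also shows why different input sizes per card are fine in principle: counts are summed, not averaged per device.

```python
from statistics import fmean, pvariance


def local_stats(values):
    # Per-device partial statistics: sum, sum of squares, element count.
    return sum(values), sum(v * v for v in values), len(values)


def reduce_stats(partials):
    # Combine partials from all devices into the global mean and the biased
    # (population) variance, which is what BN uses to normalize a batch.
    total = sum(p[0] for p in partials)
    total_sq = sum(p[1] for p in partials)
    count = sum(p[2] for p in partials)
    mean = total / count
    var = total_sq / count - mean * mean
    return mean, var


# Two "devices" holding different numbers of elements (e.g., different image sizes).
gpu0 = [1.0, 2.0, 3.0]
gpu1 = [4.0, 5.0]
mean, var = reduce_stats([local_stats(gpu0), local_stats(gpu1)])

# The reduced statistics match those of the concatenated data.
assert abs(mean - fmean(gpu0 + gpu1)) < 1e-12
assert abs(var - pvariance(gpu0 + gpu1)) < 1e-12
print(mean, var)  # 3.0 2.0
```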

In short, if you are running your code on 4 GPUs (for example), all 4 replicas of your model must call the batchnorm module in a synchronized manner. Otherwise, it is hard to define what "synchronization" even means.
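The following toy model illustrates why a mismatched number of batchnorm calls shows up as a hang rather than an error. It is a simplified sketch built on threading.Barrier, not this repository's internal mechanism; ToySyncReducer and run_replicas are hypothetical names, and the timeout exists only so the demo terminates (a real reduction with no timeout would wait forever).

```python
import threading


class ToySyncReducer:
    """Hypothetical toy model of a cross-replica reduction (not the library's code).

    The i-th reduce() call on every replica blocks until all replicas have made
    their i-th call; each replica then receives the sum of the i-th contributions.
    """

    def __init__(self, num_replicas, timeout):
        self.lock = threading.Lock()
        self.contributions = {}  # call index -> values contributed by the replicas
        # Timeout only so this demo terminates; without it a mismatch blocks forever.
        self.barrier = threading.Barrier(num_replicas, timeout=timeout)

    def reduce(self, call_index, value):
        with self.lock:
            self.contributions.setdefault(call_index, []).append(value)
        self.barrier.wait()  # raises BrokenBarrierError on timeout
        with self.lock:
            return sum(self.contributions[call_index])


def run_replicas(calls_per_replica, timeout=0.5):
    """Replica r calls reduce() calls_per_replica[r] times, contributing r + 1 each time."""
    reducer = ToySyncReducer(len(calls_per_replica), timeout)
    results = [None] * len(calls_per_replica)

    def replica(rank, num_calls):
        outputs = []
        try:
            for i in range(num_calls):
                outputs.append(reducer.reduce(i, rank + 1))
            results[rank] = outputs
        except threading.BrokenBarrierError:
            results[rank] = 'stuck'

    threads = [threading.Thread(target=replica, args=(rank, n))
               for rank, n in enumerate(calls_per_replica)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


print(run_replicas([2, 2]))  # [[3, 3], [3, 3]]  every replica calls "batchnorm" twice
print(run_replicas([2, 1]))  # ['stuck', [3]]    replica 1 skips its second call
```

In the second run, replica 0 waits for a second contribution that replica 1 never sends, which is the same shape of failure as the control-flow cases in the examples that follow.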

Several kinds of usage may violate this rule. Example 1:

if some_condition_function(input):
    output = module_contains_batchnorm(input)
else:
    output = another_module_contains_batchnorm(input)
# If some_condition_function(input) evaluates differently on different GPUs,
# the replicas may call batchnorm a different number of times.

Example 2:

# This kind of condition exists in some detection-related implementations.
if some_early_break_condition(input):
    return 0

output = module_contains_batchnorm(input)

Example 3:

# Due to some "buggy" behavior in old Python versions (e.g., Python 3.5),
# exceptions raised in child threads sometimes cannot be propagated to the main thread.
# In such cases, a child thread may die because of a bug in your code,
# and the main thread will get stuck.
output = some_function_may_trigger_error(input)
output = module_contains_batchnorm(output)
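The last case above (a child thread dying on an error) can be seen with the standard threading module alone. In CPython, an uncaught exception in a threading.Thread is reported by the threading machinery, but join() returns normally and the calling thread never receives it. This sketch shows one way to hand the exception back to the caller; run_and_capture is a hypothetical helper, and some_function_may_trigger_error is a hypothetical stand-in that simply divides by zero.

```python
import threading


def some_function_may_trigger_error(x):
    # Hypothetical stand-in for a buggy step in user code.
    return x / 0


def run_and_capture(target, *args):
    """Run target in a child thread; return (result, exception) to the calling thread."""
    outcome = {'result': None, 'error': None}

    def wrapper():
        try:
            outcome['result'] = target(*args)
        except Exception as e:  # keep the exception so the caller can inspect it
            outcome['error'] = e

    t = threading.Thread(target=wrapper)
    t.start()
    t.join()
    return outcome['result'], outcome['error']


# A plain Thread: the traceback is printed by the threading machinery,
# but join() returns normally and the calling thread never receives the exception.
t = threading.Thread(target=some_function_may_trigger_error, args=(1.0,))
t.start()
t.join()

# With the wrapper, the error is handed back to the caller and can be logged or re-raised.
result, error = run_and_capture(some_function_may_trigger_error, 1.0)
print(type(error).__name__)  # ZeroDivisionError
```

Capturing the error only helps if the caller actually reaches that point; if other replicas are still blocked waiting for the dead thread's batchnorm statistics, the program can hang before the error is ever inspected, which is why logging inside forward is useful.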