silvandeleemput / memcnn

PyTorch Framework for Developing Memory Efficient Deep Invertible Networks
MIT License

Distributed Data Parallel #44

Closed: lucidrains closed this issue 4 years ago

lucidrains commented 4 years ago

Hi Silvan again!

I have a quick question! Have you tested your framework in a distributed setting? I am currently stuck on another implementation of reversible nets, and will consider switching to memcnn if it is reported to work there.

Secondly, I also need my forward passes to be deterministic, and I am considering submitting a PR so your framework can benefit too: https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py#L7. It is the same approach used by PyTorch's checkpointing. But first I wanted to see if you were open to that possibility.

Phil

silvandeleemput commented 4 years ago

> Hi Silvan again!

Hi Phil, welcome back!

> I have a quick question! Have you tested your framework in a distributed setting? I am currently stuck on another implementation of reversible nets, and will consider switching to memcnn if it is reported to work there.

I haven't tested it yet for a distributed setting myself, but I think it should work in principle. The underlying implementation of the InvertibleModuleWrapper (the module implementing the memory savings) is very similar to how checkpointing is implemented, yet it saves more memory by also discarding the input and reconstructing it on the backward pass using the inverse. I would certainly encourage you to have a go at it and to let me know how it goes.
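For reference, a rough (untested) sketch of what such an experiment could look like, assuming the memcnn API of that time (AdditiveCoupling, InvertibleModuleWrapper with keep_input flags) and a standard DDP setup; CouplingFn and rank are made-up names:

```python
import torch
import torch.nn as nn
import memcnn


class CouplingFn(nn.Module):
    """Hypothetical sub-network used for the F and G functions of an additive coupling."""
    def __init__(self, channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=False),  # see the inplace discussion further down this thread
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return self.net(x)


coupling = memcnn.AdditiveCoupling(Fm=CouplingFn(32), Gm=CouplingFn(32))

# keep_input=False is what discards the input during the forward pass; it is
# reconstructed from the output via the inverse on the backward pass.
invertible_block = memcnn.InvertibleModuleWrapper(
    fn=coupling, keep_input=False, keep_input_inverse=False
)

x = torch.randn(2, 64, 16, 16, requires_grad=True)  # 64 channels split into two halves of 32
y = invertible_block(x)
y.sum().backward()

# For the distributed setting, a model containing such blocks would be wrapped with DDP
# as usual, e.g.:
#   model = torch.nn.parallel.DistributedDataParallel(model.to(rank), device_ids=[rank])
```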

> Secondly, I also need my forward passes to be deterministic, and I am considering submitting a PR so your framework can benefit too: https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py#L7. It is the same approach used by PyTorch's checkpointing. But first I wanted to see if you were open to that possibility.

I am definitely interested in that as well. I think it would be great if a deterministic option were available in MemCNN. Ideally, I would like to add this feature to the InvertibleCheckpointFunction with an interface similar to PyTorch's checkpoint function, i.e. something like a preserve_rng_state argument. So in principle, I would welcome this PR!
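For context, the RNG-preservation trick used by torch.utils.checkpoint (and by the linked reversible.py) boils down to something like the following sketch; this is not memcnn code, and the helper names snapshot_rng/restored_rng are made up:

```python
import contextlib
import torch

# Record the CPU/CUDA RNG state when the forward pass first runs, then restore it in a
# forked RNG context when the forward pass is recomputed during backward, so that
# stochastic layers (e.g. dropout) produce identical values both times.

def snapshot_rng(device):
    cpu_state = torch.get_rng_state()
    cuda_state = torch.cuda.get_rng_state(device) if device.type == "cuda" else None
    return cpu_state, cuda_state

@contextlib.contextmanager
def restored_rng(cpu_state, cuda_state, device):
    devices = [device] if cuda_state is not None else []
    with torch.random.fork_rng(devices=devices):
        torch.set_rng_state(cpu_state)
        if cuda_state is not None:
            torch.cuda.set_rng_state(cuda_state, device)
        yield  # recompute the forward pass here with identical randomness
```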

Note that at the moment there is also work being done on the InvertibleCheckpointFunction to add support for multiple inputs and outputs (see #43). You can wait for that to be done, or just submit the PR and let me handle the expected merge conflicts (either way is fine with me).

-Sil

lucidrains commented 4 years ago

@silvandeleemput Awesome! I'll shelve this for now and I will be back :)

naga-karthik commented 4 years ago

@lucidrains Just wanted to know whether you have been able to run memcnn in a distributed setting. I am also stuck here: I tried to run RevTorch with DDP and I am getting some strange errors. If you could make it work using memcnn, I thought I could do the same!

ibro45 commented 4 years ago

@naga-karthik what kind of errors are you getting? I've been using memcnn with PyTorch DDP without a problem.

naga-karthik commented 4 years ago

@ibro45 thanks for letting me know! Oh no, I have not used memcnn yet. I was having some issues with DDP while using the RevTorch library. I just wanted to know whether it works with memcnn so that I could dig into it. I will get back if I have any problems!

ibro45 commented 4 years ago

Oh sorry, I misread RevTorch as RevNet 😅

naga-karthik commented 4 years ago

Hi @ibro45, I ran into a strange error and I was hoping to get some help with it. Did you ever get this while running memcnn with DDP?

`RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [4, 192, 4, 17, 9]], which is output 0 of LeakyReluBackward1, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).`

I am trying to run an adapted version of RevGAN with different downsampling and upsampling operations. However, RevGAN uses a deprecated version of memcnn (the one without the InvertibleModuleWrapper) and runs on an older version of PyTorch (< 1.0). I believe I have made the changes required for the latest version of memcnn. Do you have an idea where the cause of this could be? Many thanks in advance!

ibro45 commented 4 years ago

Hey, I had the same problem, and the solution for me was to set inplace=False in the (Leaky)ReLUs that are inside the InvertibleModuleWrapper.
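I.e. something along these lines (a hypothetical block, not my actual model; the important part is the inplace=False inside the wrapped module):

```python
import torch.nn as nn

# Hypothetical sub-network placed inside an InvertibleModuleWrapper: the activations
# use inplace=False so the tensors reused during the inverse reconstruction are not
# modified in place.
block = nn.Sequential(
    nn.Conv2d(32, 32, kernel_size=3, padding=1),
    nn.LeakyReLU(0.2, inplace=False),  # was inplace=True, which triggered the error under DDP
    nn.Conv2d(32, 32, kernel_size=3, padding=1),
)
```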

@silvandeleemput I don't know if this is supposed to happen with DDP though, as keeping inplace=True works with memcnn when DDP is not used. If you'd like, I can submit a new issue with a minimal example. @naga-karthik could you confirm that this was the case for you too?

naga-karthik commented 4 years ago

Hey @ibro45, thanks for your quick response! I tried setting inplace=False only for the ReLU layer inside my InvertibleModuleWrapper, but that alone did not solve the error. However, when I changed all LeakyReLUs and ReLUs to inplace=False, I could avoid it. But this workaround seems counterintuitive to me. By setting inplace=False, aren't we increasing the memory consumption?

I also got another error now:

`RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by (1) passing the keyword argument find_unused_parameters=True to torch.nn.parallel.DistributedDataParallel; (2) making sure all forward function outputs participate in calculating loss. If you already have done the above two steps, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's forward function. Please include the loss function and the structure of the return value of forward of your module when reporting this issue (e.g. list, dict, iterable). (prepare_for_backward at /pytorch/torch/csrc/distributed/c10d/reducer.cpp:518)`

I never encountered this error while training plainly with DDP (without using memcnn). I am running out of ideas on how to solve this. Could you please give some details on how you were training your model? I could also post my code here for debugging; it's just that it is quite long and would take up a lot of space. Please let me know if you have any suggestions!


ibro45 commented 4 years ago

> However, when I changed all LeakyReLUs and ReLUs to inplace=False, I could avoid it.

Oh wow, you're right @naga-karthik, somehow I overlooked that (mostly because the model I'm using has PReLUs, which don't have an inplace option). Thanks!

> But this workaround seems counterintuitive to me. By setting inplace=False, aren't we increasing the memory consumption?

That's true. I wasn't too bothered initially because I thought it was only the case inside the invertible module, though I should've reported it anyway. But since all activation functions apparently need inplace=False, that really is a problem. @silvandeleemput any idea what might be the issue? Happy to provide a minimal example!

Regarding the second problem, are you sure it is caused by memcnn itself? When you were porting RevGAN to the newest memcnn, you might have left something out, similar to what happened here.

> I never encountered this error while training plainly with DDP (without using memcnn).

I assume you may have run CycleGAN without a problem, but not RevGAN, which indeed uses memcnn. Since you ported it to the new memcnn, I think you probably missed adding something to the loss or something of that sort, and DDP is unforgiving about that (as mentioned in this comment, same issue as the link above). I had the same error when my network had a layer which I didn't actually use in forward(): without DDP, PyTorch wasn't complaining, but with it, I got the same error as you did.
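To illustrate that failure mode with a hypothetical (made-up) model, not your code: a layer that is registered but never used in forward() leaves parameters without gradients, which DDP's reducer then complains about. The rank variable below is assumed to come from the usual DDP setup.

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(8, 8)
        self.unused = nn.Linear(8, 8)  # never called in forward() -> triggers the reducer error

    def forward(self, x):
        return self.used(x)

# Fixes: remove the unused layer, make sure its output contributes to the loss, or
# (as a last resort) silence the check:
#   model = DDP(Net().to(rank), device_ids=[rank], find_unused_parameters=True)
```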

Sorry I previously said I had no problems running DDP; I completely forgot about the problem with the inplace activation functions...

silvandeleemput commented 4 years ago

@ibro45 My first guess would be that this problem only occurs if an inplace operation is applied to either the input or the output node of a wrapped invertible module (the module wrapped with the memcnn.InvertibleModuleWrapper). Inplace operations can't be used on the input and output nodes because of the technicalities of performing the inverse during backpropagation: essentially, the computation graph gets recreated, except for the input and output nodes, which are reused (sort of).
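To make that guess concrete, a hypothetical snippet (reusing the invertible_block name from the sketch earlier in this thread):

```python
import torch

# The wrapper keeps its output around to reconstruct the input on the backward pass,
# so modifying that output in place bumps its version counter and autograd raises the
# "modified by an inplace operation" error.
y = invertible_block(x)   # output node of the wrapped invertible module
# torch.relu_(y)          # in-place op on the output node: breaks the reconstruction
y = torch.relu(y)         # the out-of-place variant is safe
```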

Note that while inplace operations can't be used inside the InvertibleModuleWrapper during training this way, the wrapper should still save memory for the non-inplace nodes by discarding their activations during the forward pass and only reconstructing the graph on the backward pass. Furthermore, once the model is trained and the InvertibleModuleWrapper is no longer needed, you can make these operations inplace again.
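For example, with a small post-training helper like this (a sketch; make_activations_inplace is a made-up name, not part of memcnn):

```python
import torch.nn as nn

# Once the memory savings of the wrapper are no longer needed, the inplace flags of
# the activations can be switched back on.
def make_activations_inplace(module: nn.Module) -> None:
    for m in module.modules():
        if isinstance(m, (nn.ReLU, nn.LeakyReLU)):
            m.inplace = True

# make_activations_inplace(trained_model)  # e.g. before running inference
```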

An example would be welcome for further inspection.