silvandeleemput / memcnn

PyTorch Framework for Developing Memory Efficient Deep Invertible Networks
MIT License

Dropout issue #56

Closed. lighShanghaitech closed this issue 1 year ago.

lighShanghaitech commented 4 years ago

Description

Hi @silvandeleemput. I am using MemCNN for blocks with a dropout layer inside, and I find that the inverse of the output is not the same as the original input. To work around this, I added a counter (n=3 before resetting) to dropout and saved the dropout masks for every layer. However, this comes at the cost of storing the dropout masks, which grows with the number of layers. Is there a more elegant way to solve this issue?
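A minimal sketch of the mask-caching workaround described above could look like the following; the name `CachedDropout`, the `reuse_count` parameter, and the mask handling are illustrative assumptions, not code from this issue or from MemCNN:

```python
import torch
import torch.nn as nn


class CachedDropout(nn.Module):
    """Dropout that reuses its last sampled mask for a fixed number of calls,
    so the forward, inverse, and temporary forward passes all see the same mask.
    Assumes the input shape stays the same while a mask is being reused."""

    def __init__(self, p=0.5, reuse_count=3):
        super().__init__()
        self.p = p
        self.reuse_count = reuse_count
        self._mask = None
        self._calls = 0

    def forward(self, x):
        if not self.training or self.p == 0.0:
            return x
        if self._mask is None:
            # sample a fresh mask and keep it for the next `reuse_count` calls
            keep = 1.0 - self.p
            self._mask = torch.bernoulli(torch.full_like(x, keep)) / keep
        out = x * self._mask
        self._calls += 1
        if self._calls >= self.reuse_count:
            # reset so a new mask is drawn for the next group of passes
            self._mask = None
            self._calls = 0
        return out
```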

silvandeleemput commented 4 years ago

Hi @lighShanghaitech, thanks for using MemCNN. I think your approach is reasonable. Just out of curiosity, why do you need n=3 to reset and not n=2?

As an alternative approach, you could try storing the random state for the dropout layers instead of the full dropout masks and regenerating the masks on the fly. That approach should significantly reduce the memory overhead. You could first test with InvertibleModuleWrapper(..., preserve_rng_state=True) and see if that works for you.
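A hedged sketch of such a test might look like the code below; `AdditiveCoupling`, `InvertibleModuleWrapper`, and the `preserve_rng_state` flag are the MemCNN pieces mentioned above, while `FmWithDropout` and the tensor shapes are assumptions made for illustration:

```python
import torch
import torch.nn as nn
import memcnn


class FmWithDropout(nn.Module):
    """Half-channel sub-network containing dropout (illustrative only)."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.drop = nn.Dropout(p=0.5)

    def forward(self, x):
        return self.drop(self.conv(x))


coupling = memcnn.AdditiveCoupling(Fm=FmWithDropout(channels=5),
                                   Gm=FmWithDropout(channels=5))
block = memcnn.InvertibleModuleWrapper(fn=coupling,
                                       keep_input=True,
                                       keep_input_inverse=True,
                                       preserve_rng_state=True)

block.train()  # dropout is only active in training mode
x = torch.randn(2, 10, 8, 8, requires_grad=True)  # grad-enabled input for training mode
y = block(x)
x_rec = block.inverse(y)

# With plain dropout this round-trip check typically fails because a fresh
# mask is drawn on the inverse pass; the question in the thread is whether
# preserve_rng_state=True makes the masks line up.
print(torch.allclose(x, x_rec, atol=1e-6))
```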

lighShanghaitech commented 3 years ago

I tried it, but it does not seem to work. I set the counter to 3 because we have one pass for the forward, one for the inverse, and one for the temporary forward.

silvandeleemput commented 3 years ago

@lighShanghaitech That's interesting, could you maybe provide some example code to reproduce the problem? Then I can have a look.

silvandeleemput commented 1 year ago

I am closing this due to inactivity.