Closed: lighShanghaitech closed this issue 1 year ago.

Description

Hi @silvandeleemput. I am using MemCNN for blocks with a dropout layer inside, and I find that the inverted input is not the same as the original input. To work around this, I added a counter (reset at n=3) to the dropout layer and saved the dropout masks for every layer. However, this comes at the cost of storing the dropout masks, which grows with the number of layers. Is there a more elegant way to solve this issue?
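For reference, a minimal sketch of this kind of setup (the block below is illustrative, not the actual code from this issue): an additive coupling whose subnetwork contains dropout, wrapped with memcnn's InvertibleModuleWrapper.

```python
import torch
import torch.nn as nn
import memcnn

# Illustrative subnetwork containing dropout.
Fm = nn.Sequential(nn.Conv2d(8, 8, kernel_size=3, padding=1),
                   nn.Dropout(p=0.5))
block = memcnn.InvertibleModuleWrapper(fn=memcnn.AdditiveCoupling(Fm),
                                       keep_input=True,
                                       keep_input_inverse=True)

block.train()
x = torch.randn(1, 16, 8, 8, requires_grad=True)  # 16 channels, split into two halves of 8
y = block(x)
x_rec = block.inverse(y)
# Typically False in training mode: the dropout layers draw fresh masks
# during inversion, so the reconstruction no longer matches the input.
print(torch.allclose(x, x_rec))
```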
Hi @lighShanghaitech, thanks for using MemCNN. I think your approach is reasonable. Just out of curiosity, why do you need n=3 to reset and not n=2?
As an alternative approach, you could try storing the random state for the dropout layers instead of the full dropout masks and regenerating the masks on the fly. That approach should significantly reduce the memory overhead. You could first test with InvertibleModuleWrapper(..., preserve_rng_state=True) and see if that works for you.
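A sketch of that test, reusing the illustrative dropout block from the earlier snippet (the keep_input arguments are standard wrapper options, not specific to this issue); preserve_rng_state=True stores the RNG state at forward time so the dropout masks can be regenerated instead of kept in memory:

```python
import torch
import torch.nn as nn
import memcnn

Fm = nn.Sequential(nn.Conv2d(8, 8, kernel_size=3, padding=1),
                   nn.Dropout(p=0.5))
block = memcnn.InvertibleModuleWrapper(fn=memcnn.AdditiveCoupling(Fm),
                                       keep_input=True,
                                       keep_input_inverse=True,
                                       preserve_rng_state=True)

block.train()
x = torch.randn(1, 16, 8, 8, requires_grad=True)
y = block(x)
x_rec = block.inverse(y)
print(torch.allclose(x, x_rec))  # check whether inversion now reproduces the input
```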
I tried it, but it does not seem to work. I set the counter to n=3 because we have one pass for the forward, one for the inverse, and one for the temporary forward.
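For context, a hypothetical reconstruction of the counter-based workaround described here (CachedDropout and its details are a guess at the approach, not the actual code): the mask is sampled once, reused across the forward, inverse, and temporary forward passes, and reset after n=3 calls.

```python
import torch
import torch.nn as nn

class CachedDropout(nn.Module):
    """Dropout that reuses one mask for n consecutive calls (hypothetical)."""

    def __init__(self, p=0.5, n=3):
        super().__init__()
        self.p = p
        self.n = n          # forward + inverse + temporary forward
        self.calls = 0
        self.mask = None

    def forward(self, x):
        if not self.training or self.p == 0.0:
            return x
        if self.mask is None:
            # Sample a fresh (inverted-dropout) mask on the first of n calls.
            keep = torch.full_like(x, 1.0 - self.p)
            self.mask = torch.bernoulli(keep) / (1.0 - self.p)
        out = x * self.mask
        self.calls += 1
        if self.calls == self.n:
            # All n passes have seen the same mask; reset for the next batch.
            self.calls = 0
            self.mask = None
        return out
```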
@lighShanghaitech That's interesting, could you maybe provide some example code to reproduce the problem? Then I can have a look.
I am closing this due to inactivity.