silvandeleemput / memcnn

PyTorch Framework for Developing Memory Efficient Deep Invertible Networks
MIT License

Backing over reversible network twice #36

Closed: lucidrains closed this issue 4 years ago

lucidrains commented 4 years ago

Hi Silvan!

I am weighing different implementations of reversible nets, which have suddenly gained new relevance with a new paper applying them to language models. I have hit a wall with a more complex architecture in which gradients are backpropagated through a reversible network twice. Because the backward pass is handled manually, the graph has to be retained for the second pass. I am currently using RevTorch, but its author has little time to investigate this issue. Before I test your library, do you know whether it would solve this?
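To make the failure mode concrete, here is a toy, pure-Python sketch of a scalar additive coupling with a hand-written backward. ToyAdditiveCoupling is a hypothetical class, not RevTorch's or MemCNN's code, and the scalar weights wf and wg stand in for the F and G sub-networks. Like memory-saving reversible implementations, it recomputes its inputs from its outputs during backward instead of storing them, then releases its state, so a second backward through the same forward pass has nothing left to work with.

```python
class ToyAdditiveCoupling:
    """Hypothetical scalar reversible block (not RevTorch or MemCNN code).

    Forward:  y1 = x1 + wf * x2,  y2 = x2 + wg * y1
    Inputs are not stored; backward recomputes them from the outputs.
    """

    def __init__(self, wf, wg):
        self.wf, self.wg = wf, wg
        self.outputs = None

    def forward(self, x1, x2):
        y1 = x1 + self.wf * x2
        y2 = x2 + self.wg * y1
        self.outputs = (y1, y2)  # only the outputs are kept
        return y1, y2

    def inverse(self, y1, y2):
        # Undo the two additive steps in reverse order.
        x2 = y2 - self.wg * y1
        x1 = y1 - self.wf * x2
        return x1, x2

    def backward(self, dy1, dy2):
        """Manual backward: returns ((dL/dx1, dL/dx2), (dL/dwf, dL/dwg))."""
        if self.outputs is None:
            raise RuntimeError("outputs already released; cannot backward twice")
        y1, y2 = self.outputs
        # Recompute the inputs instead of storing them (only x2 is needed here).
        _, x2 = self.inverse(y1, y2)
        dwg = dy2 * y1
        dy1_total = dy1 + dy2 * self.wg  # y1 also reaches the loss through y2
        dwf = dy1_total * x2
        dx1 = dy1_total
        dx2 = dy2 + dy1_total * self.wf
        # Like memory-saving wrappers, release the state once it has been used.
        self.outputs = None
        return (dx1, dx2), (dwf, dwg)


block = ToyAdditiveCoupling(wf=0.5, wg=2.0)
y1, y2 = block.forward(1.0, 3.0)
grads = block.backward(1.0, 1.0)  # first backward works
try:
    block.backward(1.0, 1.0)  # second backward through the same forward pass
except RuntimeError as err:
    print("second backward failed:", err)
```

The second call fails because the state the manual backward relies on was already released, which is why the graph (or an equivalent copy of the state) has to be retained for a second pass.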

Thank you! Phil

lucidrains commented 4 years ago

Never mind, the author of RevTorch and I went with a hacky solution. Thanks anyway!

silvandeleemput commented 4 years ago

@lucidrains Good to see that you found a solution to your problem!

Just for your information, I believe MemCNN can handle such cases as well. Internal states can be retained with the InvertibleModuleWrapper by setting its keep_input attribute to True on the first pass and back to False on the second pass. Alternatively, if the architecture allows it, you could share the weights between two separate InvertibleModuleWrapper instances.
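Both suggestions can be illustrated with a toy, pure-Python sketch. ToyInvertibleWrapper and SharedWeights are hypothetical names, not MemCNN's API: the keep_input flag here only loosely mirrors the InvertibleModuleWrapper attribute, and the real wrapper frees and restores tensor storage rather than dropping Python references. The coupling is a scalar stand-in with F(x) = wf * x and G(x) = wg * x.

```python
class SharedWeights:
    """Hypothetical parameter holder for F(x) = wf * x and G(x) = wg * x."""

    def __init__(self, wf, wg):
        self.wf, self.wg = wf, wg


class ToyInvertibleWrapper:
    """Hypothetical toy loosely mirroring the keep_input idea (not MemCNN code).

    Inputs are never stored; backward recomputes what it needs from the
    outputs. Unless keep_input is True, backward then releases the outputs,
    so no further backward through the same forward pass is possible.
    """

    def __init__(self, weights, keep_input=False):
        self.w, self.keep_input = weights, keep_input
        self.outputs = None

    def forward(self, x1, x2):
        y1 = x1 + self.w.wf * x2
        y2 = x2 + self.w.wg * y1
        self.outputs = (y1, y2)
        return y1, y2

    def backward(self, dy1, dy2):
        """Returns the parameter gradients (dL/dwf, dL/dwg)."""
        if self.outputs is None:
            raise RuntimeError("state already released; cannot backward again")
        y1, y2 = self.outputs
        x2 = y2 - self.w.wg * y1           # recompute the input half needed below
        dy1_total = dy1 + dy2 * self.w.wg  # y1 also feeds y2 through G
        grads = (dy1_total * x2, dy2 * y1)
        if not self.keep_input:
            self.outputs = None            # free the state after this pass
        return grads


shared = SharedWeights(wf=0.5, wg=2.0)

# Suggestion 1: keep the state on the first backward, release it on the second.
block = ToyInvertibleWrapper(shared, keep_input=True)
block.forward(1.0, 3.0)
first = block.backward(1.0, 1.0)
block.keep_input = False
second = block.backward(1.0, 1.0)

# Suggestion 2: two wrappers sharing the same weights, one backward each.
left, right = ToyInvertibleWrapper(shared), ToyInvertibleWrapper(shared)
left.forward(1.0, 3.0)
right.forward(1.0, 3.0)
grads_left, grads_right = left.backward(1.0, 1.0), right.backward(1.0, 1.0)
# Gradients of shared weights accumulate across both uses.
total = tuple(a + b for a, b in zip(grads_left, grads_right))
```

The first option trades memory for the ability to backpropagate twice through one forward pass; the second keeps each wrapper single-use and lets the shared weights collect gradients from both.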