silvandeleemput / memcnn

PyTorch Framework for Developing Memory Efficient Deep Invertible Networks
MIT License

How to join several reversible additive blocks together #72

Open LoveLiveSun opened 2 years ago

LoveLiveSun commented 2 years ago

Hello, this is great work, but I am confused about one thing. I want to implement this equation:

(x1, x2) = x
y1 = x1 + Fm(x2)
y2 = x2 + Gm(y1)
y = (y1, y2)

and then chain it again:

(x3, x4) = y
y3 = x3 + Fm(x4)
y4 = x4 + Gm(y3)
yy = (y3, y4)

and chain more of these blocks after that. How can I do this? Can you tell me how to finish this work?

silvandeleemput commented 2 years ago

Hi, have a look at the memcnn.AdditiveCoupling class, which is essentially what you are trying to use. You can implement your example in the following manner:

import torch
import memcnn

Fm = ...   # TODO pick something nice
Gm = ...   # TODO pick something nice
Fm2 = ...  # TODO pick something nice
Gm2 = ...  # TODO pick something nice

network = torch.nn.Sequential(
  memcnn.AdditiveCoupling(Fm=Fm, Gm=Gm),
  memcnn.AdditiveCoupling(Fm=Fm2, Gm=Gm2),
)

x = ...  # TODO define data that can be split into (x1, x2) along the channels

yy = network(x)
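
To make the arithmetic of the two stacked couplings concrete, here is a minimal plain-Python sketch that does not use memcnn or torch at all. The functions f1, g1, f2, and g2 are hypothetical stand-ins for Fm, Gm, Fm2, and Gm2, and a flat list split in half stands in for a tensor split along the channels.

```python
# Plain-Python illustration of two chained additive couplings.
# f1, g1, f2, g2 are hypothetical stand-ins for Fm, Gm, Fm2, Gm2.

def additive_coupling(x, f, g):
    """Split x in half, then compute y1 = x1 + f(x2) and y2 = x2 + g(y1)."""
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    y1 = [a + b for a, b in zip(x1, f(x2))]
    y2 = [a + b for a, b in zip(x2, g(y1))]
    return y1 + y2  # list concatenation, i.e. y = (y1, y2)

def f1(v): return [2.0 * t for t in v]
def g1(v): return [t + 1.0 for t in v]
def f2(v): return [t * t for t in v]
def g2(v): return [-t for t in v]

x = [1.0, 2.0, 3.0, 4.0]
y = additive_coupling(x, f1, g1)   # first block:  x -> y
yy = additive_coupling(y, f2, g2)  # second block: y -> yy
print(y)   # [7.0, 10.0, 11.0, 15.0]
print(yy)  # [128.0, 235.0, -117.0, -220.0]
```

Chaining more blocks is just more of the same: the output of one coupling becomes the input of the next, which is what the Sequential container does for the real modules.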

For example values for Fm, Gm, Fm2, and Gm2, have a look at the example in the README. Finally, if you want memory savings, you can wrap the additive couplings inside a memcnn.InvertibleModuleWrapper, e.g.:

network = torch.nn.Sequential(
  memcnn.InvertibleModuleWrapper(fn=memcnn.AdditiveCoupling(Fm=Fm, Gm=Gm)),
  memcnn.InvertibleModuleWrapper(fn=memcnn.AdditiveCoupling(Fm=Fm2, Gm=Gm2)),
)
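
The memory savings rely on the couplings being invertible: the input of each block can be recomputed from its output, so it does not have to be kept around. The plain-Python sketch below only illustrates that inversion for a chain of two blocks; it is not how memcnn implements it internally, and f1, g1, f2, and g2 are again hypothetical stand-ins for Fm, Gm, Fm2, and Gm2.

```python
# Plain-Python illustration of inverting a chain of additive couplings.
# f1, g1, f2, g2 are hypothetical stand-ins for Fm, Gm, Fm2, Gm2.

def forward(x, f, g):
    """y1 = x1 + f(x2), y2 = x2 + g(y1)."""
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    y1 = [a + b for a, b in zip(x1, f(x2))]
    y2 = [a + b for a, b in zip(x2, g(y1))]
    return y1 + y2

def inverse(y, f, g):
    """Undo forward: x2 = y2 - g(y1), then x1 = y1 - f(x2)."""
    half = len(y) // 2
    y1, y2 = y[:half], y[half:]
    x2 = [a - b for a, b in zip(y2, g(y1))]
    x1 = [a - b for a, b in zip(y1, f(x2))]
    return x1 + x2

def f1(v): return [2.0 * t for t in v]
def g1(v): return [t + 1.0 for t in v]
def f2(v): return [t * t for t in v]
def g2(v): return [-t for t in v]

blocks = [(f1, g1), (f2, g2)]

x = [1.0, 2.0, 3.0, 4.0]
out = x
for f, g in blocks:            # forward through the chain
    out = forward(out, f, g)

recovered = out
for f, g in reversed(blocks):  # invert the blocks in reverse order
    recovered = inverse(recovered, f, g)

print(recovered == x)  # True
```

Note that f and g themselves never need to be invertible; only additions and subtractions are undone.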

I hope that this explanation helps.