Open LoveLiveSun opened 2 years ago
Hi, have a look at the memcnn.AdditiveCoupling class, which is essentially what you are trying to use. You can implement your example as follows:
import torch
import memcnn

# Each of Fm, Gm, Fm2, Gm2 should be a torch.nn.Module that keeps the shape of its input.
Fm = ...   # TODO pick something nice
Gm = ...   # TODO pick something nice
Fm2 = ...  # TODO pick something nice
Gm2 = ...  # TODO pick something nice

network = torch.nn.Sequential(
    memcnn.AdditiveCoupling(Fm=Fm, Gm=Gm),
    memcnn.AdditiveCoupling(Fm=Fm2, Gm=Gm2),
)

x = ...  # TODO define data that can be split into (x1, x2) along the channels
yy = network(x)
For example values of Fm, Gm, Fm2, and Gm2, have a look at the example in the README.
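As a rough illustration of what Fm and Gm could look like, here is a minimal sketch of a small convolutional block. The class name ExampleOperation, the channel count of 10, and the input size are assumptions made for this example only. The one real requirement is that each block maps a tensor with half of the input channels to a tensor of the same shape, so that the additions in the coupling are well defined.

```python
import torch
import torch.nn as nn


class ExampleOperation(nn.Module):
    """Shape-preserving block, usable as Fm or Gm (illustrative only)."""

    def __init__(self, channels):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.seq(x)


channels = 10  # assumed channel count of x; must be even so x splits in two
Fm = ExampleOperation(channels // 2)
Gm = ExampleOperation(channels // 2)
Fm2 = ExampleOperation(channels // 2)
Gm2 = ExampleOperation(channels // 2)

x = torch.rand(2, channels, 8, 8)  # (batch, channels, height, width)
x1, x2 = torch.chunk(x, 2, dim=1)  # split along the channel dimension
```

With these definitions, Fm(x2) has the same shape as x1, which is what the coupling needs.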
Finally, if you want memory savings, you can wrap the additive couplings inside a memcnn.InvertibleModuleWrapper, e.g.:
network = torch.nn.Sequential(
    memcnn.InvertibleModuleWrapper(fn=memcnn.AdditiveCoupling(Fm=Fm, Gm=Gm)),
    memcnn.InvertibleModuleWrapper(fn=memcnn.AdditiveCoupling(Fm=Fm2, Gm=Gm2)),
)
I hope that this explanation helps.
Hello, this is great work, but I am a bit confused. I want to implement these equations:
(x1, x2) = x
y1 = x1 + Fm(x2)
y2 = x2 + Gm(y1)
y = (y1, y2)
and then apply the same step again:
(x3, x4) = y
y3 = x3 + Fm(x4)
y4 = x4 + Gm(y3)
yy = (y3, y4)
and then chain more of these. How can I do this? Can you tell me how to finish this?
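To make the equations concrete, here is a minimal sketch in plain Python that applies them to a pair of scalars and then inverts them. It does not use memcnn: the functions coupling_forward and coupling_inverse and the stand-ins chosen for Fm, Gm, Fm2, and Gm2 are hypothetical names for this illustration. In a real network, x1 and x2 would be the two channel halves of a tensor.

```python
import math


def coupling_forward(x, Fm, Gm):
    """(x1, x2) = x; y1 = x1 + Fm(x2); y2 = x2 + Gm(y1); y = (y1, y2)."""
    x1, x2 = x
    y1 = x1 + Fm(x2)
    y2 = x2 + Gm(y1)
    return (y1, y2)


def coupling_inverse(y, Fm, Gm):
    """Undo coupling_forward by subtracting in the reverse order."""
    y1, y2 = y
    x2 = y2 - Gm(y1)
    x1 = y1 - Fm(x2)
    return (x1, x2)


# Hypothetical stand-ins for the four functions; they need not be invertible.
Fm, Gm = math.tanh, lambda v: v * v
Fm2, Gm2 = math.sin, abs

x = (0.5, -1.25)
y = coupling_forward(x, Fm, Gm)      # first coupling
yy = coupling_forward(y, Fm2, Gm2)   # second coupling, fed with y

# Invert the stack in reverse order to recover the input.
y_rec = coupling_inverse(yy, Fm2, Gm2)
x_rec = coupling_inverse(y_rec, Fm, Gm)
```

Chaining more couplings follows the same pattern: feed each output into the next coupling, and invert them in the opposite order. The input is recovered up to floating-point rounding, even though the functions themselves are not invertible.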