Open · Wang-Tianyu opened this issue 3 years ago

Hi, thanks for the great library.
I am interested in an ODE of the form `dz(s)/ds = f_theta(s, r, z(s))`, where r can be a vector independent of the input x. Does the current implementation support this feature?
Hey, thanks for using torchdyn.

What you ask is not currently supported by the `DataControl` layer, but it can be done pretty easily in two ways. The idea is an `nn.Module` that saves r in `self.r` and uses it in a forward that receives `(s, z)` only; the only `u` the library considers is `DataControl`'s, which triggers the assignment of the current data input to `module.u` during integration. Depending on how general you'd prefer your implementation to be, you can simply modify `DataControl`:
```python
import torch
import torch.nn as nn

class CustomDataControl(nn.Module):
    """Data-control module. Allows for data-control inputs at arbitrary
    points of the DEFunc, here extended with a fixed conditioning tensor r."""
    def __init__(self, r):
        super().__init__()
        self.u = None   # kept from torchdyn's DataControl; unused by this custom forward
        self.r = r      # conditioning tensor, independent of the input x

    def forward(self, x):
        # concatenate the conditioning tensor to the state along the feature dimension
        return torch.cat([x, self.r], 1).to(x)
```
which is slightly hacky but works, or you can alter the logic in `_prep_integration` to allow for custom assignments to `module.u`.
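For reference, here is a minimal usage sketch of the custom layer above. The `NeuralODE` import, constructor arguments, and return values are assumptions based on recent torchdyn releases, so adjust them to the version you have installed:

```python
import torch
import torch.nn as nn
from torchdyn.core import NeuralODE   # older releases expose it elsewhere, e.g. torchdyn.models

batch, state_dim, cond_dim = 64, 2, 4
r = torch.randn(batch, cond_dim, requires_grad=True)   # conditioning vector, one row per sample

# vector field f_theta(s, r, z): CustomDataControl appends r to the state before the MLP
vector_field = nn.Sequential(
    CustomDataControl(r),
    nn.Linear(state_dim + cond_dim, 64),
    nn.Tanh(),
    nn.Linear(64, state_dim),
)
model = NeuralODE(vector_field, solver='dopri5')

z0 = torch.randn(batch, state_dim)
t_span = torch.linspace(0., 1., 10)
t_eval, traj = model(z0, t_span)   # traj has shape (len(t_span), batch, state_dim)
```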
Hey, thanks for your quick reply.
I managed to make it work following your suggestion!
One remaining question: in this case, does the conditioning vector r receive any gradient? I am interested in building a conditional CNF like PointFlow (https://github.com/stevenygd/PointFlow). Do you think that is doable with this library?
Glad to hear it worked. Depending on your implementation, it should work just fine and r should get gradients. It is certainly possible to build PointFlows with torchdyn :)
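As a quick sanity check that the conditioning tensor actually receives gradients, you can backpropagate through the custom layer alone. This is only an illustrative snippet reusing the `CustomDataControl` sketch above, not part of the torchdyn API:

```python
import torch

# hypothetical shapes; r requires grad so it can be trained / conditioned on
r = torch.randn(8, 4, requires_grad=True)
ctrl = CustomDataControl(r)

z = torch.randn(8, 2)        # a batch of states
loss = ctrl(z).sum()         # stand-in for a NeuralODE / CNF loss
loss.backward()

print(r.grad.shape)          # torch.Size([8, 4]) -> gradients flow into r
```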