DiffEqML / torchdyn

A PyTorch library entirely dedicated to neural differential equations, implicit models and related numerical methods
https://torchdyn.org
Apache License 2.0

data-controlled ODE with general control signal #121

Open Wang-Tianyu opened 2 years ago

Wang-Tianyu commented 2 years ago

Hi, thanks for the great library.

I am interested in ODEs of the form dz(s)/ds = f_θ(s, r, z(s)), where r is a conditioning vector independent of the input x. Does the current implementation support this?
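For reference, the ODE being asked about can be sketched in plain PyTorch, independent of torchdyn. This is a minimal sketch assuming a small MLP for f_θ and a fixed-step explicit Euler solver standing in for torchdyn's solvers (it does not use torchdyn's API); the names ConditionedField and odeint_euler are hypothetical.

```python
import torch
import torch.nn as nn


class ConditionedField(nn.Module):
    """Hypothetical vector field f_theta(s, r, z) for dz/ds = f_theta(s, r, z(s))."""

    def __init__(self, z_dim, r_dim, hidden=32):
        super().__init__()
        # the network sees [z, r, s] concatenated along the feature dimension
        self.net = nn.Sequential(
            nn.Linear(z_dim + r_dim + 1, hidden),
            nn.Tanh(),
            nn.Linear(hidden, z_dim),
        )

    def forward(self, s, r, z):
        # broadcast the scalar depth s and the condition r over the batch
        s_col = torch.full((z.shape[0], 1), float(s), dtype=z.dtype, device=z.device)
        r_b = r.expand(z.shape[0], -1)  # r may have shape (r_dim,) or (1, r_dim)
        return self.net(torch.cat([z, r_b, s_col], dim=1))


def odeint_euler(f, z0, r, s_span=(0.0, 1.0), steps=20):
    """Fixed-step explicit Euler integration of dz/ds = f(s, r, z)."""
    s0, s1 = s_span
    h = (s1 - s0) / steps
    z = z0
    for k in range(steps):
        s = s0 + k * h
        z = z + h * f(s, r, z)
    return z


torch.manual_seed(0)
f = ConditionedField(z_dim=2, r_dim=3)
x = torch.randn(8, 2)  # batch of inputs; z(0) = x
r = torch.randn(1, 3)  # condition vector, independent of x
zS = odeint_euler(f, x, r)
print(zS.shape)
```

The same r is shared by every sample in the batch; a per-sample condition of shape (batch, r_dim) would also work, since expand leaves a full-size batch dimension unchanged.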

Zymrael commented 2 years ago

Hey, thanks for using torchdyn.

What you're asking for is not currently supported by the DataControl layer, but it can be done fairly easily in two ways. The first is a custom data-control module that concatenates r to the state:

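```python
import torch
import torch.nn as nn


class CustomDataControl(nn.Module):
    """Data-control module. Concatenates a fixed conditioning vector r
    to the state at arbitrary points of the DEFunc."""
    def __init__(self, r):
        super().__init__()
        self.u = None  # unused: the control signal comes from r, not the input data
        self.r = r     # shape (r_dim,) or (batch, r_dim)

    def forward(self, x):
        # match x's dtype/device and broadcast r over the batch before concatenating
        r = self.r.to(x)
        if r.dim() == 1:
            r = r.expand(x.shape[0], -1)
        return torch.cat([x, r], 1)
```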

This is slightly hacky but works. Alternatively, you can alter the logic in _prep_integration to allow custom assignments to module.u.
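The second option, custom assignments to module.u, can be illustrated without touching torchdyn internals. This is a minimal sketch assuming a data-control-style layer that concatenates whatever tensor is stored in its u attribute; ControlConcat and assign_control are hypothetical names, and assign_control only stands in for a modified _prep_integration that assigns r instead of the default value.

```python
import torch
import torch.nn as nn


class ControlConcat(nn.Module):
    """Hypothetical data-control-style layer: concatenates self.u to its input."""

    def __init__(self):
        super().__init__()
        self.u = None  # filled in externally before integration

    def forward(self, x):
        u = self.u.to(x)  # match x's dtype and device
        if u.dim() == 1:
            u = u.expand(x.shape[0], -1)  # share one control vector across the batch
        return torch.cat([x, u], dim=1)


def assign_control(model, u):
    """Hypothetical stand-in for a modified _prep_integration: set u on every ControlConcat."""
    for module in model.modules():
        if isinstance(module, ControlConcat):
            module.u = u


r = torch.randn(3)  # condition vector, independent of the input
field = nn.Sequential(ControlConcat(), nn.Linear(2 + 3, 16), nn.Tanh(), nn.Linear(16, 2))
assign_control(field, r)
out = field(torch.randn(4, 2))
print(out.shape)
```

Walking model.modules() means the control is set on every matching layer, however deeply it is nested in the vector field.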

Wang-Tianyu commented 2 years ago

Hey, thanks for your quick reply.

I managed to make it work following your suggestion!

One remaining question: in this case, does the condition r receive gradients? I am interested in building a conditional CNF like this work: https://github.com/stevenygd/PointFlow. Do you think that is doable with this library?

Zymrael commented 2 years ago

Glad to hear it worked. Depending on your implementation, it should work just fine, and r should receive gradients. It is certainly possible to build PointFlows with torchdyn :)
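A minimal sketch of both points, assuming plain PyTorch rather than torchdyn's API. Here r is an nn.Parameter, and the log-density follows the instantaneous change of variables used in continuous normalizing flows, d log p(z(s))/ds = -tr(∂f/∂z), so log p(x | r) = log N(z(1)) + ∫ tr(∂f/∂z) ds. The trace is computed exactly with autograd, which is only practical for small state dimensions (larger models typically use a stochastic trace estimator). The field is time-independent for brevity, integration is fixed-step Euler, and all names are hypothetical.

```python
import math

import torch
import torch.nn as nn


class CondField(nn.Module):
    """Hypothetical conditional vector field f_theta(z, r), time-independent for brevity."""

    def __init__(self, z_dim, r_dim, hidden=32):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(z_dim + r_dim, hidden), nn.Tanh(), nn.Linear(hidden, z_dim))

    def forward(self, z, r):
        return self.net(torch.cat([z, r.expand(z.shape[0], -1)], dim=1))


def exact_trace(dz, z):
    """Exact per-sample tr(df/dz), one autograd call per dimension (small z_dim only)."""
    tr = torch.zeros(z.shape[0], dtype=z.dtype, device=z.device)
    for i in range(z.shape[1]):
        # samples do not interact, so column i of this gradient holds d dz[b, i] / d z[b, i]
        tr = tr + torch.autograd.grad(dz[:, i].sum(), z, create_graph=True)[0][:, i]
    return tr


def cond_cnf_logprob(f, x, r, steps=20):
    """log p(x | r): integrate data x to the latent z(1) with Euler, accumulating tr(df/dz)."""
    h = 1.0 / steps
    z = x.detach().requires_grad_(True)  # the data are the initial state z(0)
    trace_integral = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
    for _ in range(steps):
        dz = f(z, r)
        trace_integral = trace_integral + h * exact_trace(dz, z)
        z = z + h * dz
    d = z.shape[1]
    # standard normal base density evaluated at the latent z(1)
    log_base = -0.5 * (z ** 2).sum(dim=1) - 0.5 * d * math.log(2 * math.pi)
    return log_base + trace_integral


torch.manual_seed(0)
f = CondField(z_dim=2, r_dim=3)
r = nn.Parameter(torch.randn(3))  # learnable condition; could also be an encoder's output
x = torch.randn(16, 2)
loss = -cond_cnf_logprob(f, x, r).mean()  # negative log-likelihood
loss.backward()
print(r.grad is not None)
```

Because create_graph=True keeps the trace computation differentiable, the backward pass reaches r through both the state trajectory and the log-density term.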