thibmonsel / torchdde

Neural network compatible DDEs
https://thibmonsel.github.io/torchdde/
Apache License 2.0
9 stars 0 forks source link

Wrong history function domain for DDE adjoint #34

Closed ReHoss closed 2 months ago

ReHoss commented 2 months ago

Line 108, modified in https://github.com/thibmonsel/torchdde/commit/a2b58e470b350ccadc207839acb07bd67886daf2, causes a bug.

https://github.com/thibmonsel/torchdde/blob/a659f8c62e9274a86e6bdbad722e4f2e3666f265/torchdde/adjoint_dde.py#L107-L114

When max(ctx.func.delays) < 0.5, the list [ctx.t1, 2 * max(ctx.func.delays) * ctx.t1] is not increasing: for ctx.t1 > 0 the second element is smaller than the first. This causes TorchLinearInterpolator to raise a ValueError.

In the MWE below, add_t == tensor([0.9000, 0.5400]) (that is, [0.9, 2 * 0.3 * 0.9]), which is not increasing. IMO, 2 * max(ctx.func.delays) * ctx.t1 should be replaced with the correct analytical upper bound of the domain of the history function for the reverse (adjoint) equation. That domain is something like $[t_1, t_1 + \max_i \tau_i]$, so the second element would become ctx.t1 + max(ctx.func.delays).

Error trace

Traceback (most recent call last):
  File "/home/hosseinkhan/Documents/work/phd/git_repositories/control_dde/venv/venv_control_dde/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/hosseinkhan/Documents/work/phd/git_repositories/control_dde/venv/venv_control_dde/lib/python3.10/site-packages/torch/autograd/function.py", line 301, in apply
    return user_fn(self, *args)
  File "/home/hosseinkhan/Documents/work/phd/git_repositories/control_dde/venv/venv_control_dde/lib/python3.10/site-packages/torchdde/adjoint_dde.py", line 111, in backward
    adjoint_interpolator = TorchLinearInterpolator(
  File "/home/hosseinkhan/Documents/work/phd/git_repositories/control_dde/venv/venv_control_dde/lib/python3.10/site-packages/torchdde/global_interpolation/linear_interpolation.py", line 30, in __init__
    raise ValueError("ts must be monotonically increasing.")
ValueError: ts must be monotonically increasing.

Minimal working example
# from typing import Any, Sequence
from typing import Any, Sequence

import torch
import torch.nn as nn
import torchdde

class MLP(nn.Module):
    def __init__(
        self,
        in_size: int,
        out_size: int,
        width_size: int,
        depth: int,
        delays: torch.Tensor | None = None,
        residual: bool = False,
        dropout: float = 0.0,
        activation: nn.Module = nn.ReLU(inplace=False),
        final_activation: nn.Module = nn.Identity(),
    ):
        super().__init__()
        self.delays = delays
        self.layers = nn.ModuleList(
            [nn.Linear(in_size, width_size)]
            + [nn.Linear(width_size, width_size) for _ in range(depth - 1)]
            + [nn.Linear(width_size, out_size)]
        )

        self.dropout = nn.Dropout(dropout)
        self.activation = activation
        self.final_activation = final_activation
        self.residual = residual

    def forward(
        self,
        t: float,
        z: torch.Tensor,
        args: Any = None,
        history: Sequence[torch.Tensor] | None = None,
    ) -> torch.Tensor:
        for i, layer in enumerate(self.layers):
            z = self.dropout(z)
            z = layer(z)
            if i != (len(self.layers) - 1):
                z = self.activation(z)
        return self.final_activation(z)

def main():
    """Reproduce issue #34: the adjoint backward pass fails for small delays.

    Setup:

    - A delay-free MLP vector field is integrated from ``t0 = 0.4`` to
      ``t1 = 0.9``.
    - A single delay of ``0.3`` is used.
    - Gradients come from the continuous adjoint
      (``discretize_then_optimize=False``).

    The function then backpropagates an MSE loss against zeros. Because
    ``max(delays) = 0.3 < 0.5``, the affected torchdde version builds the
    non-increasing grid ``[0.9, 0.54]`` for the adjoint interpolator, so
    ``loss.backward()`` raises ``ValueError``.
    """
    tensor_time_points = torch.tensor([0.4000, 0.5000, 0.6000, 0.7000, 0.8000, 0.9000])
    # The history covers [0.0, 0.4], which reaches back more than the 0.3
    # delay from t0 = 0.4.
    tensor_time_points_history = torch.tensor([0.0000, 0.1000, 0.2000, 0.3000, 0.4000])
    # 4 trajectories x 5 history time points x 1 feature.
    tensor_trajectory_history = torch.tensor(
        [
            [[-2.3789], [-2.1936], [-1.9956], [-1.7516], [-1.6767]],
            [[0.6427], [0.7463], [0.8204], [0.9334], [1.0079]],
            [[2.2204], [1.9631], [1.8347], [1.7059], [1.4606]],
            [[-4.5422], [-4.0878], [-3.7671], [-3.3287], [-3.0112]],
        ]
    )
    # A delay below 0.5 is what triggers the bug.
    delays = torch.nn.Parameter(torch.tensor([0.3]))
    history_interpolator = torchdde.TorchLinearInterpolator(
        tensor_time_points_history, tensor_trajectory_history
    )
    solver = torchdde.RK4()
    args = None
    controller = torchdde.step_size_controller.constant.ConstantStepSizeController()
    dt0 = torch.tensor(0.1)
    # False selects the adjoint (optimize-then-discretize) backward path,
    # which is where the failing interpolator is built.
    bool_discretize_then_optimize = False

    def history_func(t: torch.Tensor) -> torch.Tensor:
        return history_interpolator(t)

    # Fix the initial condition which is a function of time
    _y0 = history_func

    model = MLP(
        in_size=tensor_trajectory_history.shape[-1],
        out_size=tensor_trajectory_history.shape[-1],
        width_size=32,
        depth=2,
        delays=delays,
    )

    tensor_trajectory_hat = torchdde.integrate(
        func=model,
        solver=solver,
        t0=tensor_time_points[0],
        t1=tensor_time_points[-1],
        ts=tensor_time_points,
        y0=_y0,  # We give the initial condition
        args=args,
        stepsize_controller=controller,
        dt0=dt0,
        delays=delays,
        discretize_then_optimize=bool_discretize_then_optimize,
    )

    loss_func = torch.nn.MSELoss()
    loss = loss_func(tensor_trajectory_hat, torch.zeros_like(tensor_trajectory_hat))
    # The ValueError from the issue surfaces here, inside the adjoint backward.
    loss.backward()


if __name__ == "__main__":
    main()
thibmonsel commented 2 months ago

Thanks for pointing that out! The latest HEAD (https://github.com/thibmonsel/torchdde/commits/master/) should fix this.

thibmonsel commented 2 months ago

Closing this issue.