rtqichen / torchdiffeq

Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.
MIT License

How to work with control, namely a PID controller #240

Open Tomke997 opened 1 year ago

Tomke997 commented 1 year ago

Hi,

I am using this library for parameter estimation of a cruise-control system with a PID controller. My task is to represent the system as an ODE and estimate the control parameters Kp, Ki, and Kd. Because the PID controller keeps internal state (the previous error, used for the derivative term), I have to use fixed-step solvers, which is not ideal, especially with the adjoint method. Since grid_points and eps were removed, I am no longer able to force the solver to move strictly forward in time. Is there a workaround for these issues? Thanks in advance!

Here is the code for my model:

import torch
from torchdiffeq import odeint, odeint_adjoint


class Cruise_model(torch.nn.Module):
    def __init__(self, solver, adjoint, parameters=None):
        # Call the superclass initializer.
        super().__init__()

        # Set model parameters.
        self.Params = None                                     # PID gains [Kp, Ki, Kd]
        self.Setpoint = 50.0/3.6                               # Target speed [m/s]
        self.Dt = 0.010                                        # Fixed solver step [s]
        self.Error_previous = 0.0                              # Stored error for the D-term
        self.Veh_drag_coeff = 0.33                             # Vehicle aerodynamic drag coefficient [N/m^2.s^2]
        self.Veh_height = 57.3 * 25.4/1000                     # Vehicle height [m]
        self.Veh_width = 69.9 * 25.4/1000                      # Vehicle width [m]
        self.Veh_drag_xA = self.Veh_width*self.Veh_height                # Effective front cross-sectional area [m^2]
        self.CdA = self.Veh_drag_coeff * self.Veh_drag_xA
        self.M = 2860/2.25                                     # Vehicle mass [kg]
        self.solver = solver
        self.adjoint = adjoint
        self.StartPoint = None

        self.x0 = None
        self.t = None

    def forward(self, t, x):              
        y, E = x
        dydt = torch.zeros_like(x)

        # Straight-through trunc: the forward value equals y.trunc(),
        # but the gradient flows through y unchanged.
        workaround_int = y + y.trunc().detach() - y.detach()
        error = self.Setpoint - workaround_int
        inte = E

        pid = self.Params[0]*error + self.Params[1]*inte + self.Params[2]*(error - self.Error_previous)/self.Dt
        F_ego = pid

        D_ego = self.CdA*(y ** 2)

        if y > 0:
            F_net = F_ego - D_ego
        else:
            F_net = F_ego + D_ego

        dydt[0] = F_net/self.M
        dydt[1] = error

        # NOTE: mutating state inside forward() assumes strictly
        # increasing, fixed-size steps, which rules out adaptive solvers.
        self.Error_previous = error
        return dydt

    def solve_ode(self, sp, t, params=None):  # used for some optimization methods
        if params is not None:
            self.Params = params
        self.StartPoint = sp
        self.x0 = torch.tensor([self.StartPoint, (self.Setpoint - int(self.StartPoint))*self.Dt/2])
        self.t = t
        self.Error_previous = 0.0

        if self.adjoint:
            return odeint_adjoint(
                self.forward, self.x0, self.t,
                method=self.solver,
                options=dict(step_size=self.Dt),
                adjoint_params=list(self.parameters()) + [self.Params],
            )
        else:
            return odeint(
                self.forward, self.x0, self.t,
                method=self.solver,
                options=dict(step_size=self.Dt),
            )
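On the fixed-step constraint itself: the stored Error_previous exists only to finite-difference the D-term. Since error = Setpoint − y, its exact time derivative is d(error)/dt = −dy/dt, so the derivative term can be folded into the momentum balance algebraically and the right-hand side becomes a pure function of (t, x), which an adaptive solver may evaluate in any order. A plain-Python sketch with a hand-rolled Euler loop, just to exercise the idea (gains and constants are illustrative, not fitted values):

```python
# Stateless PID-in-ODE: state x = (y, E), where E is the running
# integral of the error. No previous-error value is stored.
def rhs(t, x, kp, ki, kd, setpoint=50.0 / 3.6, m=2860 / 2.25, cda=0.85):
    y, E = x
    error = setpoint - y
    # Aerodynamic drag, signed as in the model above.
    drag = cda * y * y if y > 0 else -cda * y * y
    # Momentum balance with the D-term folded in:
    #   m*dy/dt = kp*error + ki*E + kd*d(error)/dt - drag
    # and d(error)/dt = -dy/dt, hence
    #   (m + kd)*dy/dt = kp*error + ki*E - drag
    dydt = (kp * error + ki * E - drag) / (m + kd)
    return (dydt, error)

def euler_path(rhs_fn, x0, t0, t1, dt, **kw):
    # Fixed-step explicit Euler, only to demonstrate the RHS here;
    # the point is that rhs() would also suit an adaptive solver.
    x, t = list(x0), t0
    while t < t1 - 1e-12:
        dx = rhs_fn(t, x, **kw)
        x = [xi + dt * dxi for xi, dxi in zip(x, dx)]
        t += dt
    return x

y_final, E_final = euler_path(rhs, (0.0, 0.0), 0.0, 60.0, 0.01,
                              kp=800.0, ki=40.0, kd=120.0)
```

Because this right-hand side no longer mutates module state, it could in principle be handed to odeint with an adaptive method such as dopri5.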