Tim-Salzmann / l4casadi

Use PyTorch Models with CasADi for data-driven optimization or learning-based optimal control. Supports Acados.
MIT License
379 stars 29 forks

Support for the time-varying neural network model constraints modelling #48

Closed xueminchi closed 1 month ago

xueminchi commented 2 months ago

Hi @Tim-Salzmann,

I'm trying to model a neural network (NN) constraint that is time-varying. A static version would look like the provided NeRF optimization example. In my case, the output of the NN model depends on some additional parameters beyond the state of the robot.

My current code is learned_cons = l4c.L4CasADi(NNModel(), batched=False, scripting=True). I tried updating the parameters, which are stored as public attributes of NNModel(), so that learned_cons returns the correct values when it is called and thereby models the time-varying constraint. However, updating the parameters of NNModel has no effect on the evaluation of the constraint.

I then moved the update into the forward() function of the neural network. I found that forward() is only called a few times. The model is updated during those calls, but forward() is not called again afterwards.

The neural network model looks like this:

class NNModel(nn.Module):
    def __init__(self):
        super(NNModel, self).__init__()
        # (parameters used to update)

    def forward(self, x):
        self.updateParameters()
        # (other computations)
        return values

    def updateParameters(self):
        # (computations)
        pass

Is there a way to resolve this issue? I'm not sure whether this is the correct way to model a time-varying constraint, so please let me know if I'm doing something wrong.

Best, Xuemin

Tim-Salzmann commented 2 months ago

Hi,

I am not entirely following. What about adding time to the state and passing it to the model explicitly?

Best Tim

xueminchi commented 2 months ago

Hi, thanks for the advice. Let me clarify the question with an example. Consider the following signed distance function (SDF) model of a circle:

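import torch
import torch.nn as nn

class sdfCircle(nn.Module):
    def __init__(self):
        super(sdfCircle, self).__init__()
        # Placeholders; center and radius must be set before calling forward().
        self.center = None
        self.radius = None

    def forward(self, x):
        # Signed distance from point x to the circle boundary.
        dist = torch.norm(x - self.center) - self.radius
        return dist.reshape(-1, 1)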

To model a moving circle SDF, the center needs to be updated according to the time and the velocity, e.g. center = center + dt * v. Could you provide some insight into how to implement this with l4casadi? Thanks.

Best, Xuemin

Tim-Salzmann commented 2 months ago

How about something like this:

def forward(self, inp):
    x = inp[0]
    t = inp[1]
    center = self.center_function(t)
    radius = self.radius_function(t)
    dist = torch.norm(x - center) - radius
    return dist.reshape(-1, 1)
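
The suggestion above can be sketched as a complete module. This is only an illustration under stated assumptions: the class name MovingCircleSDF, the flat input layout [x, y, t], and the constant-velocity motion center(t) = center0 + t * velocity are hypothetical choices, not part of l4casadi. The point is that time enters as an ordinary model input, so nothing on the module has to be mutated between solver calls.

```python
import torch
import torch.nn as nn


class MovingCircleSDF(nn.Module):
    """Signed distance to a circle whose center moves at constant velocity.

    Hypothetical example: the input layout [x, y, t] and the motion model
    are assumptions made for illustration.
    """

    def __init__(self, center0, velocity, radius):
        super().__init__()
        # Buffers are fixed tensors stored as part of the module state.
        self.register_buffer("center0", torch.as_tensor(center0, dtype=torch.float32))
        self.register_buffer("velocity", torch.as_tensor(velocity, dtype=torch.float32))
        self.radius = float(radius)

    def forward(self, inp):
        # inp is a flat vector [x, y, t]; time is just another input entry.
        inp = inp.reshape(-1)
        pos = inp[:2]
        t = inp[2]
        center = self.center0 + t * self.velocity  # center(t) = center0 + t * v
        dist = torch.norm(pos - center) - self.radius
        return dist.reshape(-1, 1)


sdf = MovingCircleSDF(center0=[0.0, 0.0], velocity=[1.0, 0.0], radius=0.5)
d0 = sdf(torch.tensor([2.0, 0.0, 0.0]))  # circle centered at (0, 0)
d1 = sdf(torch.tensor([2.0, 0.0, 1.0]))  # circle centered at (1, 0)
print(d0.item(), d1.item())
```

A module like this would presumably be wrapped with l4c.L4CasADi in the same way as the models earlier in the thread, with the time variable stacked into the CasADi input vector next to the position.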

xueminchi commented 2 months ago

Thank you so much for the quick feedback. I have tried this. The only way I found to make this forward function work correctly is to set mutable=True and update the model online. If I don't enable mutable, then even though I put the time into the forward function, it still doesn't work. However, the update operation takes a lot of time. Could you explain this a bit? Thanks!

Tim-Salzmann commented 2 months ago

"even though I put the time into the forward function, it still doesn't work."

Can you explain exactly what does not work? If you can calculate the center and radius given the time, then I am not sure why this would not work.

Best Tim

xueminchi commented 2 months ago

To check whether the circle updates correctly, I added a print(center) statement to the forward function. After the code started running, it printed only 5 times. After that, nothing more was printed from the forward function, and judging by the distance, the circle center also stopped updating while the solver kept running. Is this normal?

def forward(self, inp):
    x = inp[0]
    t = inp[1]
    center = self.center_function(t)
    radius = self.radius_function(t)
    dist = torch.norm(x - center) - radius
    print('center is : ', center)
    return dist.reshape(-1, 1)

Tim-Salzmann commented 2 months ago

It is expected that the printout does not appear while the optimization is running. This does not mean that the model does not work.

The model is exported and compiled to C / C++, where the Python print statement is not included.
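
The same effect can be reproduced in plain PyTorch. The sketch below uses torch.jit.trace purely as a stand-in for an export step; l4casadi's own export pipeline may work differently, and the function sdf and the counter calls are hypothetical names. Python side effects such as print or a counter run only while the function is being recorded, while the tensor computation, including its dependence on the time input, is kept in the exported graph.

```python
import torch

calls = {"count": 0}


def sdf(inp):
    # Python side effect, comparable to the print() in forward().
    calls["count"] += 1
    pos, t = inp[:2], inp[2]
    # Assumed motion model: the center moves along the x-axis with time t.
    center = torch.stack([t, torch.zeros_like(t)])
    return torch.norm(pos - center) - 0.5


# Tracing executes the Python function to record the tensor operations.
traced = torch.jit.trace(sdf, torch.tensor([2.0, 0.0, 0.0]))
calls_after_export = calls["count"]

# Calls to the traced graph no longer run the Python body, so the counter
# stays fixed, yet the result still changes with the time input.
d0 = traced(torch.tensor([2.0, 0.0, 0.0]))
d1 = traced(torch.tensor([2.0, 0.0, 1.0]))
print(calls["count"] == calls_after_export, d0.item(), d1.item())
```

Comparing returned values for different time inputs, as done here, is a more reliable check than print statements inside forward().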

Tim-Salzmann commented 2 months ago

Let me know if you have any more open questions. Otherwise feel free to close this issue.