Tim-Salzmann / l4casadi

Use PyTorch Models with CasADi for data-driven optimization or learning-based optimal control. Supports Acados.
MIT License
370 stars · 28 forks

Multiple inputs in the forward call #28

Closed baryn closed 7 months ago

baryn commented 7 months ago

Hi Tim,

I have been working with l4casadi for a couple of months and am now stuck on the case where my network takes multiple inputs in its forward call. FYI, I am implementing an NN-based MPC with acados.

If I am correct, l4casadi expects a single input argument, as in the acados.py example: `res_model = self.learned_dyn(x)`. In my case I would like to call it with `res_model = self.learned_dyn(a, s)` (action and state). The thing is that in my network's forward method I want to modify the inputs (scaling) before passing them through the network, like this:

    def forward(self, a, s):
        # scale inputs
        a = (a - self.action_mean) / self.action_std
        s = (s - self.state_mean) / self.state_std

        # run through the network
        x = torch.cat([a, s], dim=1)
        x = self.linear(x)
        x = self.hidden(x)
        x = self.out(x)

        # un-scale the output
        x = (x * self.output_std) + self.output_mean
        return x

It works in plain PyTorch, but not when wrapped with l4casadi. Calling `f_expl = self.learned_dyn(self.inputs, self.state)` raises:

L4CasADi.forward() takes 2 positional arguments but 3 were given

Thus, to work with a single-argument forward method, I tried to extract the data myself:

    def forward(self, s):
        # extract action and state from the single input
        action = s[0][0:self.action_size]
        state = s[0][self.action_size:]
        # scale inputs
        a = (action - self.action_mean) / self.action_std
        s = (state - self.state_mean) / self.state_std

        # run through the network
        x = torch.cat([a, s])
        x = self.linear(x)
        x = self.hidden(x)
        x = self.out(x)

        # un-scale the output
        x = (x * self.output_std) + self.output_mean
        return x

It works when called from PyTorch, but with l4casadi I get the following error:

L4CasADi requires the model output to be a matrix (2 dimensions) but has 1 dimensions. For models which expects a batch dimension, the output should be a matrix of [1, d]
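The error arises because indexing with `s[0]` removes the batch dimension, so every downstream tensor, including the output, is 1-D. A minimal sketch of the shape difference (the input width 8 and slice width 3 are arbitrary values for illustration):

```python
import torch

s = torch.zeros(1, 8)    # L4CasADi passes the model a [1, X] batched input
flat = s[0][0:3]         # indexing with s[0] drops the batch dim -> shape [3]
kept = s[:, 0:3]         # slicing with ':' keeps the batch dim -> shape [1, 3]
```

Using the `s[:, ...]` form throughout the forward method preserves the `[1, d]` output shape that L4CasADi checks for.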

Regards, Bryan

Tim-Salzmann commented 7 months ago

Hi Bryan,

Thanks for reaching out. Adding support for multiple inputs/outputs is something I want to implement at some point, but I have not yet had the time.

Having a single input and splitting it within PyTorch as you do would have been my recommendation for an easy fix.

> Is it related to the model_expects_batch_dim parameter? I am not sure I got the meaning of this parameter.

Most likely, yes!

Here is what the README says about this:

If your PyTorch model expects a batch dimension as first dimension (which most models do) you should pass model_expects_batch_dim=True to the L4CasADi constructor. The MX input to the L4CasADi component is then expected to be a vector of shape [X, 1]. L4CasADi will add a batch dimension of 1 automatically such that the input to the underlying PyTorch model is of shape [1, X].

Assuming you set model_expects_batch_dim=True, your PyTorch model will receive an input vector from L4CasADi of shape [B, X] with batch size B=1 and X being the input dimension (dimension of s + dimension of a in your case). Similarly, L4CasADi expects the output of your PyTorch model to be [B, Y] with B=1. Hence, you get the error message that your output is missing the batch dimension. Please ensure that your PyTorch model fulfills these requirements.
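As a sanity check against these requirements, one can verify in plain PyTorch that the model maps a [1, X] input to a [1, Y] output before wrapping it in L4CasADi. A minimal sketch (the class name, layer sizes, and dimensions are made up for illustration):

```python
import torch
import torch.nn as nn

class Dyn(nn.Module):
    """Toy dynamics model that splits a concatenated [action, state] input."""
    def __init__(self, action_size=2, state_size=4, out_size=4):
        super().__init__()
        self.action_size = action_size
        self.net = nn.Linear(action_size + state_size, out_size)

    def forward(self, s):
        # keep the batch dimension when splitting (s has shape [1, X])
        a = s[:, :self.action_size]
        x = s[:, self.action_size:]
        return self.net(torch.cat([a, x], dim=1))  # shape [1, out_size]

model = Dyn()
y = model(torch.zeros(1, 6))  # [1, X] in -> [1, Y] out
print(y.shape)                # torch.Size([1, 4])
```

If this check passes, the same model should satisfy L4CasADi's input/output shape requirements with model_expects_batch_dim=True.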

Should you still have problems, please provide a minimal example that is not working.

Best Tim

baryn commented 7 months ago

Tim,

Thank you very much for your answer. I understand; so it is OK to split the data within PyTorch. I'll double-check the dimensions.

Cheers, Bryan

Tim-Salzmann commented 7 months ago

Hi Bryan,

just wondering if this is resolved?

Thanks Tim

baryn commented 7 months ago

Hi Tim,

I haven't had the time to do a final check yet, but I suppose I can extract the data in PyTorch as you said, with the proper batch dimensions.

Thanks, Bryan

Tim-Salzmann commented 7 months ago

Feel free to re-open in case this does not work as expected!

Thanks

baryn commented 6 months ago

Hi Tim,

To close out this issue: as expected, I needed to reshape the intermediate inputs properly, keeping the batch dimension when slicing:

    def forward(self, s):
        # split the concatenated input, keeping the batch dimension
        action = s[:, 0:self.action_size]
        state = s[:, self.action_size:]
        # scale inputs
        a = (action - self.action_mean) / self.action_std
        s = (state - self.state_mean) / self.state_std

        # run through the network
        x = torch.cat([a, s], dim=1)
        x = self.linear(x)
        x = self.hidden(x)
        x = self.out(x)

        # un-scale the output
        x = (x * self.output_std) + self.output_mean
        return x

Now it runs fine within L4CasADi.

Regards, Bryan