google / trajax

Apache License 2.0
186 stars 23 forks source link

ILQR optimizer doesn't support 1D scalar dynamical systems #7

Open Nusha97 opened 1 year ago

Nusha97 commented 1 year ago

When trying to run a 1D quadratic control affine nonlinear system of the form as shown below, the ILQR implementation is unable to handle scalar valued systems and results in a dimensionality mismatch error. Please find error and code below.

image

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import os

import jax
from jax import device_put
from jax import vmap
from jax.config import config
import jax.numpy as np
import numpy as onp

from trajax import optimizers
from trajax.integrators import euler
from trajax.integrators import rk4
import matplotlib.pyplot as plt

def quadratic_nonlinear(x, u, t, params=(5, 10)):
    """
    Simple quadratic nonlinear system where we introduce reference trajectory as input
    :param x: 1D scalar state
    :param u: 1D input
    :param params: Kp, Kd gains for PD control law
    :return xdot: 1D array of shape 1
    """
    del t
    Kp, Kd = params
    r = np.squeeze(u)
    # xdot = (x ** 2 + Kp * (x - r) - Kd * rdot)/(1 - Kd)
    xdot = x ** 2 + Kp * (x - r)
    return np.array([xdot])

class ILQR_test():
    """
    Testing ILQR implementation in trajax for simple nonlinear systems
    """
    def __init__(self):
        pass

    def discretize(self, type='euler', dynamics=None):
        if dynamics is not None:
            self.dynamics = dynamics

        self.dynamics = euler(self.dynamics, dt=0.01)
        if type != 'euler':
            self.dynamics = rk4(self.dynamics, dt=0.01)

    def testQuadNonLinear(self, maxiter):
        """
        Calling ilqr on quadratic nonlinear system with input as reference trajectory
        :param maxiter: maximum number of iterations to take in ilqr
        :return: list of ilqr fn output
        """
        horizon = 100
        dynamics = rk4(quadratic_nonlinear, dt=0.01)

        true_params = (100.0, 10.0, 1.0)

        def cost(params, state, action, t):
            final_weight, stage_weight, action_weight = params

            state_err = state - action
            state_cost = stage_weight * (state_err ** 2 + action ** 2)
            # action_cost = action_weight * np.squeeze(action) ** 2
            return np.where(t == horizon, final_weight * state_cost,
                            state_cost)

        x0 = np.array([-0.9])
        U0 = np.zeros((horizon, 1))
        X, U, obj, grad, adj, lqr_val, total_iter = optimizers.ilqr(
            functools.partial(cost, true_params), dynamics, x0, U0, maxiter)
        return [X, U, obj, grad, adj, lqr_val, total_iter]

test = ILQR_test()
traj_cost = []
num_iter = [2, 30, 40, 50]

for i in num_iter:
    print(i)
    # traj = test.apply_ilqr(x0=onp.random.randn(2), U=onp.random.randn(2), maxiter=i, dynamics=rk4(quadratic_nonlinear, dt=0.01))
    # traj = test.testPendulumReadmeExample(maxiter=i)
    traj = test.testQuadNonLinear(maxiter=i)
    traj_cost.append(traj[2])

X = traj[0]
U = traj[1]

print(traj_cost)