ACDSLab / MPPI-Generic

Templated C++/CUDA implementation of Model Predictive Path Integral Control (MPPI)
https://acdslab.github.io/mppi-generic-website/
BSD 2-Clause "Simplified" License
87 stars 8 forks source link

ERROR: Nan in state inside plant #5

Closed XinChen-stars closed 1 day ago

XinChen-stars commented 2 months ago

Hi, thank you for sharing the code!

Description

I tried following the MPC Usage: Listing 2 & Listing 3 the cartpole_plant_example running result: cartpole However,when trying to run the quadrotor example it returned the following error: quadrotor_error quadrotor plant code :

#pragma once
#include <mppi/core/base_plant.hpp>
#include <mppi/dynamics/quadrotor/quadrotor_dynamics.cuh>
template <class CONTROLLER_T>
class SimpleQuadrotorPlant : public BasePlant<CONTROLLER_T>
{
public:
  using control_array = typename QuadrotorDynamics::control_array;
  using state_array = typename QuadrotorDynamics::state_array;
  using output_array = typename QuadrotorDynamics::output_array;

  SimpleQuadrotorPlant(std::shared_ptr<CONTROLLER_T> controller, int hz, int optimization_stride)
    : BasePlant<CONTROLLER_T>(controller, hz, optimization_stride)
  {
    system_dynamics_ = std::make_shared<QuadrotorDynamics>();
  }

  void pubControl(const control_array& u)
  {
    state_array state_derivative;
    output_array dynamics_output;
    state_array prev_state = current_state_;
    float t = this->state_time_;
    float dt = this->controller_->getDt();
    system_dynamics_->step(prev_state, current_state_, state_derivative, u, dynamics_output, t, dt);
    current_time_ += dt;
  }

  void pubNominalState(const state_array& s)
  {
  }

  void pubFreeEnergyStatistics(MPPIFreeEnergyStatistics& fe_stats)
  {
  }

  int checkStatus()
  {
    return 0;
  }

  double getCurrentTime()
  {
    return current_time_;
  }

  double getPoseTime()
  {
    return this->state_time_;
  }

  double getAvgLoopTime() const
  {
    return this->avg_loop_time_ms_;
  }

  double getLastOptimizationTime() const
  {
    return this->optimization_duration_;
  }

  state_array current_state_ = state_array::Zero();

protected:
  std::shared_ptr<QuadrotorDynamics> system_dynamics_;
  double current_time_ = 0.0;
};

quadrotor example code :

#include <mppi/instantiations/quadrotor_mppi/quadrotor_mppi.cuh>
#include <quadrotor_plant.hpp>

const int NUM_TIMESTEPS = 100;
const int NUM_ROLLOUTS = 1024;
const int DYN_BLOCK_X = 32;
using DYN_T = QuadrotorDynamics;
const int DYN_BLOCK_Y = DYN_T::STATE_DIM;
using COST_T = QuadrotorQuadraticCost;
using FB_T = DDPFeedback<DYN_T, NUM_TIMESTEPS>;
using SAMPLING_T = mppi::sampling_distributions::GaussianDistribution<DYN_T::DYN_PARAMS_T>;
using CONTROLLER_T = VanillaMPPIController<DYN_T, COST_T, FB_T, NUM_TIMESTEPS, NUM_ROLLOUTS, SAMPLING_T>;
using CONTROLLER_PARAMS_T = CONTROLLER_T::TEMPLATED_PARAMS;

using PLANT_T = SimpleQuadrotorPlant<CONTROLLER_T>;

int main(int argc, char** argv)
{
  float dt = 0.02;
  DYN_T dynamics;                     // set up dynamics
  COST_T cost;                        // set up cost
  FB_T fb_controller(&dynamics, dt);  // set up feedback controller
  // set up sampling distribution
  SAMPLING_T sampler;
  auto sampler_params = sampler.getParams();
  std::fill(sampler_params.std_dev, sampler_params.std_dev + DYN_T::CONTROL_DIM, 10.0);
  sampler.setParams(sampler_params);

  // set up MPPI Controller
  CONTROLLER_PARAMS_T controller_params;
  controller_params.dt_ = dt;
  controller_params.lambda_ = 1.0;
  controller_params.dynamics_rollout_dim_ = dim3(DYN_BLOCK_X, DYN_BLOCK_Y, 1);
  controller_params.cost_rollout_dim_ = dim3(96, 1, 1);
  std::shared_ptr<CONTROLLER_T> controller =
      std::make_shared<CONTROLLER_T>(&dynamics, &cost, &fb_controller, &sampler, controller_params);

  // Create plant
  PLANT_T plant(controller, (1.0 / dt), 1);

  std::atomic<bool> alive(true);
  for (int t = 0; t < 10000; t++)
  {
    plant.updateState(plant.current_state_, (t + 1) * dt);
    plant.runControlIteration(&alive);
  }

  std::cout << "Avg Optimization time: " << plant.getAvgOptimizationTime() << " ms" << std::endl;
  std::cout << "Last Optimization time: " << plant.getLastOptimizationTime() << " ms" << std::endl;
  std::cout << "Avg Loop time: " << plant.getAvgLoopTime() << " ms" << std::endl;
  std::cout << "Avg Optimization Hz: " << 1.0 / (plant.getAvgOptimizationTime() * 1e-3) << " Hz" << std::endl;

  auto control_sequence = controller->getControlSeq();
  std::cout << "State: \n" << plant.current_state_.transpose() << std::endl;
  std::cout << "Control Sequence:\n" << control_sequence << std::endl;
  return 0;
}

Is there anything I am overlooking on coding an quadrotor example?

JasonGibson274 commented 2 months ago

Hello,

Yes the quadrotor dynamics have quaternions in them. By initializing it all to zero it will throw nans when doing update state here https://github.com/ACDSLab/MPPI-Generic/blob/main/include/mppi/dynamics/quadrotor/quadrotor_dynamics.cu#L182C24-L182C77.

Just use the method getZeroState https://github.com/ACDSLab/MPPI-Generic/blob/main/include/mppi/dynamics/quadrotor/quadrotor_dynamics.cu#L212 to get a valid initial state.

JasonGibson274 commented 2 months ago

I will make a note to move the examples to using the getZeroState method to prevent this issue in the future.

XinChen-stars commented 2 months ago

Thank you for your reply! I solved this problem by using the method getZeroState:

state_array current_state_ = system_dynamics_->getZeroState();

success