sail-sg / envpool

C++-based high-performance parallel environment execution engine (vectorized env) for general RL environments.
https://envpool.readthedocs.io
Apache License 2.0

[BUG] incorrect parsing of actions in multiagent environment #296

Open gresavage opened 8 months ago

gresavage commented 8 months ago

Describe the bug

When calling Step in a multi-agent environment with more than 2 players, where all players' actions are processed simultaneously, the actions received by the C++ API are offset from the actions sent through the Python API by 1 rather than by "max_num_players".

I am trying to implement the Multi-Agent Particle Environment (MPE) and have found that if I have num_envs environments and send a [batch_size, ...] vector of actions to step in the Python API, then the Step function in the C++ API receives actions offset by the number of previous calls to Step in that iteration.

For example, if I have 2 environments with 3 agents each and I send a flat vector of 6 actions [0, 1, 2, 3, 4, 5]:

The actions are not offset by the correct amount: the first value starts at index 1 instead of at max_num_players. The only way to correct this is to know how many times Step has already been called, undo the index advancement done by ParseActions, and offset the actions by the correct amount. However, this would only be a workaround, and it cannot be implemented anyway because env_index_ is a private variable, so derived classes can't use it.
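To make the offset concrete, here is a rough NumPy sketch (not the actual envpool internals) of the slicing I expect each C++ Step call to receive versus what it appears to receive, using the 2-environment, 3-agent example above:

import numpy as np

num_envs, max_num_players = 2, 3
actions = np.arange(num_envs * max_num_players)  # [0, 1, 2, 3, 4, 5]

# Expected: the k-th env stepped in the iteration gets its own block of
# max_num_players actions.
expected = [actions[k * max_num_players:(k + 1) * max_num_players]
            for k in range(num_envs)]
# expected -> [array([0, 1, 2]), array([3, 4, 5])]

# Observed: each successive Step call is advanced by only 1, so the second
# call starts at index 1 instead of at max_num_players.
observed = [actions[k:k + max_num_players] for k in range(num_envs)]
# observed -> [array([0, 1, 2]), array([1, 2, 3])]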

To Reproduce

My working example consists of 8 files:

  1. core.h where the base classes are defined
  2. default_params.h where the default values for all environments are defined
  3. simple_env.h where the spec, Simple MPE base environment, and corresponding envpool async environment are defined
  4. simple_spread.h where the Simple Spread MPE environment and corresponding envpool async environment are defined
  5. mpe.cc the python bindings file
  6. BUILD my Bazel build file
  7. __init__.py the module file according to the EnvPool docs
  8. registration.py the registration file according to the EnvPool docs

This project uses Eigen as a C++ dependency. I have omitted the changes to setup.cfg and the envpool core files that make the custom environment accessible.

Below are the files for my C++ MPE environment

core.h


#ifndef ENVPOOL_MPE_CORE_H_
#define ENVPOOL_MPE_CORE_H_

#include <Eigen/Core>
#include <array>
#include <cmath>
#include <iostream>
#include <limits>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include "envpool/core/async_envpool.h"
#include "envpool/core/env.h"
#include "envpool/mpe/default_params.h"

using namespace std;

namespace mpe {
const double pi = 3.14159265358979323846;

class MPEEntityState {
 public:
  Eigen::ArrayXd pos_;
  Eigen::ArrayXd vel_;
  int dim_p_;
  uniform_real_distribution<double> u_dist_;
  normal_distribution<double> n_dist_;
  Eigen::ArrayXd c_;
  int dim_c_;

 public:
  MPEEntityState(int dim_p = 2) : dim_p_(dim_p) {
    pos_.setZero(dim_p_);
    vel_.setZero(dim_p_);
  }
  virtual ~MPEEntityState() = default;

  virtual void Reset(mt19937* gen) {
    pos_.setZero(dim_p_);
    vel_.setZero(dim_p_);
    NewRandomState(gen);
  }

  void RandomCommNoise(mt19937* gen, double noise_level) {}

  virtual void NewRandomState(mt19937* gen) {
    for (auto& p : pos_.reshaped()) {
      p = u_dist_(*gen) * 2.0 - 1.0;
    }
  }

  void SetPos(Eigen::ArrayXd* new_pos) { pos_ = *new_pos; }
  void SetVel(Eigen::ArrayXd* new_vel) { vel_ = *new_vel; }
};

class MPEAgentState : virtual public MPEEntityState {
 public:
  Eigen::ArrayXd c_;
  int dim_c_;

 public:
  MPEAgentState(int dim_p = 2, int dim_c = 0)
      : MPEEntityState(dim_p), dim_c_(dim_c) {
    c_.setZero(dim_c_);
  };

  void Reset(mt19937* gen) override {
    MPEEntityState::Reset(gen);
    c_.setZero();
  }

  void RandomCommNoise(mt19937* gen, double noise_level) {
    for (auto& c : c_.reshaped()) {
      c += u_dist_(*gen) * noise_level;
    }
  }
};
class MPEAction {
 public:
  int dim_p_;
  int dim_c_;
  Eigen::ArrayXd u_;
  Eigen::ArrayXd c_;

 public:
  MPEAction(int dim_p = 2, int dim_c = 0) : dim_p_(dim_p), dim_c_(dim_c) {
    u_.setZero(dim_p_);
    c_.setZero(dim_c_);
  };
  virtual ~MPEAction() = default;

  void Reset() {
    u_.setZero();
    c_.setZero();
  }
};

class MPEEntity {
 public:
  string name_;
  double size_;
  bool moveable_;
  bool collide_;
  double density_;
  tuple<uint_fast16_t, uint_fast16_t, uint_fast16_t> color_;
  double max_speed_;
  double mass_;
  MPEEntityState* state_;
  double volume_;
  MPEAction* action_;
  bool silent_ = true;
  double u_noise_;
  double c_noise_;
  double u_range_;
  double accel_;

 public:
  MPEEntity(
      string name, double size = 0.05, double mass = 1.0, bool moveable = true,
      bool collide = false,
      tuple<uint_fast16_t, uint_fast16_t, uint_fast16_t> color = AGENT_COLOR,
      double max_speed = -1.0, int dim_p = 2)
      : name_(name),
        size_(size),
        moveable_(moveable),
        collide_(collide),
        color_(color),
        max_speed_(max_speed),
        mass_(mass),
        state_(new MPEEntityState(dim_p)) {
    volume_ = pow(pi, static_cast<double>(dim_p) / 2.0) /
              tgamma(1 + static_cast<double>(dim_p) / 2.0) *
              pow(size, static_cast<double>(dim_p));
    density_ = volume_ / mass_;
  }
  virtual ~MPEEntity() { delete state_; }

  virtual void Reset(mt19937* gen) { state_->Reset(gen); }
};

class MPEAgent : virtual public MPEEntity {
 public:
  bool silent_;
  double u_noise_;
  double c_noise_;
  double u_range_;
  MPEAgentState* state_;
  MPEAction* action_;
  double accel_;

 public:
  MPEAgent(
      string name, double size = 0.05, double mass = 1.0, bool moveable = true,
      bool collide = false, bool silent = false, double u_noise = 1.0,
      double c_noise = 1.0, double accel = 5.0,
      tuple<uint_fast16_t, uint_fast16_t, uint_fast16_t> color = AGENT_COLOR,
      double max_speed = -1.0, int dim_p = 2, int dim_c = 0)
      : MPEEntity(name, size, mass, moveable, collide, color, max_speed, dim_p),
        silent_(silent),
        u_noise_(u_noise),
        c_noise_(c_noise),
        accel_(accel) {
    state_ = new MPEAgentState(dim_p, dim_c);
    action_ = new MPEAction(dim_p, dim_c);
    u_range_ = 1.0;
  }
  virtual ~MPEAgent() {
    delete state_;
    delete action_;
  }
  void Reset(mt19937* gen) override {
    state_->Reset(gen);
    action_->Reset();
  }
};

class MPEWorld {
 public:
  int num_agents_;
  int num_landmarks_;
  vector<string> agent_names_;
  vector<string> landmark_names_;
  vector<string> entity_names_;
  vector<MPEAgent*> agents_;
  vector<MPEEntity*> landmarks_;
  int dim_p_;
  int dim_c_;
  double dt_;
  double damping_;
  double contact_force_;
  double contact_margin_;
  int num_entities_;
  uniform_real_distribution<double> u_dist_;
  normal_distribution<double> n_dist_;

 public:
  MPEWorld(int num_agents = 1, int num_landmarks = 1, int dim_p = 2,
           int dim_c = 0, double dt = 0.1, double damping = 0.25,
           double contact_force = 1e2, double contact_margin = 1e-3,
           vector<double> u_noise = vector<double>({0.0, 0.0}),
           vector<double> c_noise = vector<double>({1.0}),
           vector<bool> silent = vector<bool>({true}),
           vector<double> size = vector<double>({0.05, 0.05}),
           vector<double> mass = vector<double>({1.0, 1.0}),
           vector<bool> moveable = vector<bool>({true, true}),
           vector<bool> collide = vector<bool>({false, false}),
           vector<double> accel = vector<double>({5.0}),
           vector<tuple<uint_fast16_t, uint_fast16_t, uint_fast16_t>> color =
               vector<tuple<uint_fast16_t, uint_fast16_t, uint_fast16_t>>(
                   {AGENT_COLOR, OBS_COLOR}),
           vector<double> max_speed = vector<double>({-1.0, -1.0}))
      : num_agents_(num_agents),
        num_landmarks_(num_landmarks),
        dim_p_(dim_p),
        dim_c_(dim_c),
        dt_(dt),
        damping_(damping),
        contact_force_(contact_force),
        contact_margin_(contact_margin),
        num_entities_(num_agents + num_landmarks) {
    EnsureSize(u_noise, num_agents_, false);
    EnsureSize(c_noise, num_agents_, false);
    EnsureSize(silent, num_agents_, false);
    EnsureSize(accel, num_agents_, false);
    EnsureSize(size, num_entities_, true);
    EnsureSize(mass, num_entities_, true);
    EnsureSize(moveable, num_entities_, true);
    EnsureSize(collide, num_entities_, true);
    EnsureSize(color, num_entities_, true);
    EnsureSize(max_speed, num_entities_, true);

    for (int i = 0; i < num_agents_; i++) {
      stringstream name;
      name << "agent_" << i;
      // argument order follows the MPEAgent constructor:
      // (name, size, mass, moveable, collide, silent, u_noise, c_noise, accel,
      //  color, max_speed, dim_p, dim_c)
      agents_.push_back(new MPEAgent(name.str(), size[i], mass[i], moveable[i],
                                     collide[i], silent[i], u_noise[i],
                                     c_noise[i], accel[i], color[i],
                                     max_speed[i], dim_p, dim_c));
    };

    for (int i = 0; i < num_landmarks_; i++) {
      stringstream name;
      int entity_id = num_agents_ + i;
      name << "landmark_" << i;
      landmarks_.push_back(new MPEEntity(
          name.str(), size[entity_id], mass[entity_id], moveable[entity_id],
          collide[entity_id], color[entity_id], max_speed[entity_id], dim_p));
    };
  }

  void ResetWorld(mt19937* gen) {
    for (const auto& agent : agents_) {
      agent->Reset(gen);
    }
    for (const auto& landmark : landmarks_) {
      landmark->Reset(gen);
    }
  }

  void StepWorld(mt19937* gen) {
    Eigen::ArrayXXd p_force;
    Eigen::ArrayXXd comms;
    p_force.setZero(num_agents_, dim_p_);
    comms.setZero(num_agents_, dim_c_);
    int i = 0;
    for (const auto& agent : agents_) {
      p_force(i, Eigen::all) = agent->action_->u_.reshaped();
      comms(i, Eigen::all) = agent->action_->c_.reshaped();
      ++i;
    }
    p_force = ApplyActionForce(gen, p_force);
    ApplyEnvForce(p_force);

    for (int entity_idx = 0; entity_idx < num_entities_; ++entity_idx) {
      auto p = p_force(entity_idx, Eigen::all);
      if (entity_idx < num_agents_) {
        if (!agents_[entity_idx]->moveable_) {
          continue;
        }
        IntegrateState(p, agents_[entity_idx]);
      } else {
        if (!landmarks_[entity_idx - num_agents_]->moveable_) {
          continue;
        }
        IntegrateState(p, landmarks_[entity_idx - num_agents_]);
      }
    }

    UpdateAgentState(gen);
  }

  template <typename Action>
  void DecodeDiscreteActions(Action& actions) {
    for (int i = 0; i < num_agents_; ++i) {
      Eigen::ArrayXd action;
      Eigen::ArrayXd comm;
      action.setZero(dim_p_);
      comm.setZero(dim_c_);
      int act = actions["action"_][i];
      int chnl = act / (2 * dim_p_ + 1);
      act %= (2 * dim_p_ + 1);
      if (agents_[i]->moveable_) {
        int action_idx = act <= 2 ? 0 : 1;
        double sent = (act % 2 == 0 ? -1.0 : 1.0) * (act != 0);
        action(action_idx, 0) = sent;
      }
      if (!agents_[i]->silent_) {
        comm(chnl, 0) = 1;
      }
      agents_[i]->action_->u_ = action * agents_[i]->accel_;
      agents_[i]->action_->c_ = comm;
    }
  }

  template <typename Action>
  void DecodeContinuousActions(Action& actions) {
    for (int i = 0; i < num_agents_; ++i) {
      Eigen::ArrayXd action;
      Eigen::ArrayXd comm;
      action.setZero(dim_p_);
      comm.setZero(dim_c_);
      vector<double> act = actions["action"_][i];
      // clamp each component to [0, 1] in place; appending to `act` while
      // iterating over it would grow the vector indefinitely
      for (auto& temp : act) {
        temp = (temp > 1.0) ? 1.0 : temp;
        temp = (temp < 0.0) ? 0.0 : temp;
      }
      if (agents_[i]->moveable_) {
        action(0) += act[2] - act[1];
        action(1) += act[4] - act[3];
      }
      if (!agents_[i]->silent_) {
        for (int j = 0; j < dim_c_; ++j) {
          comm(j) = act[2 * dim_p_ + j + 1];
        }
      }
      agents_[i]->action_->u_ = action * agents_[i]->accel_;
      agents_[i]->action_->c_ = comm;
    }
  }

 protected:
  Eigen::ArrayXXd ApplyActionForce(mt19937* gen, Eigen::ArrayXXd& actions) {
    Eigen::ArrayXXd p_force;
    p_force.setZero(num_entities_, dim_p_);
    int agent_idx = 0;
    for (const auto& agent : agents_) {
      if (agent->moveable_) {
        Eigen::ArrayXd u_noise;
        u_noise.setZero(dim_p_);

        if (agent->u_noise_) {
          for (int i = 0; i < dim_p_; ++i) {
            u_noise(i, 0) = n_dist_(*gen) * agent->u_noise_;
          }
        }

        p_force(agent_idx, Eigen::all) =
            (actions(agent_idx, Eigen::all).reshaped() + u_noise).reshaped();
      }
      agent_idx += 1;
    }
    return p_force;
  };

  void ApplyEnvForce(Eigen::ArrayXXd& p_force) {
    for (int i = 0; i < num_entities_; ++i) {
      if (i < num_agents_) {
        for (int j = i + 1; j < num_entities_; ++j) {
          if (j < num_agents_) {
            auto [f_a, f_b] = GetCollisionForce(agents_[i], agents_[j]);
            p_force(i, Eigen::all) += f_a.reshaped();
            p_force(j, Eigen::all) += f_b.reshaped();
          } else {
            auto [f_a, f_b] =
                GetCollisionForce(agents_[i], landmarks_[j - num_agents_]);
            p_force(i, Eigen::all) += f_a.reshaped();
            p_force(j, Eigen::all) += f_b.reshaped();
          }
        }
      } else {
        for (int j = i + 1; j < num_entities_; ++j) {
          if (j < num_agents_) {
            auto [f_a, f_b] =
                GetCollisionForce(landmarks_[i - num_agents_], agents_[j]);
            p_force(i, Eigen::all) += f_a.reshaped();
            p_force(j, Eigen::all) += f_b.reshaped();
          } else {
            auto [f_a, f_b] = GetCollisionForce(landmarks_[i - num_agents_],
                                                landmarks_[j - num_agents_]);
            p_force(i, Eigen::all) += f_a.reshaped();
            p_force(j, Eigen::all) += f_b.reshaped();
          }
        }
      }
    }
  };

  template <typename U, typename V>
  tuple<Eigen::ArrayXd, Eigen::ArrayXd> GetCollisionForce(U& a, V& b) {
    Eigen::ArrayXd f_a;
    Eigen::ArrayXd f_b;
    f_a.setZero(dim_p_);
    f_b.setZero(dim_p_);
    if (a != b && a->collide_ && b->collide_) {
      Eigen::ArrayXd delta_pos = a->state_->pos_ - b->state_->pos_;
      double dist = sqrt(delta_pos.pow(2.0).sum());
      double dist_min = a->size_ + b->size_;
      if (dist < dist_min) {
        double penetration =
            log(1 + exp(-(dist - dist_min) / contact_margin_)) *
            contact_margin_;
        Eigen::ArrayXd force = contact_force_ * delta_pos * penetration / dist;

        if (a->moveable_) {
          f_a = force;
        }

        if (b->moveable_) {
          f_b = -force;
        }
      }
    }
    return make_tuple(f_a, f_b);
  };

  template <typename U, typename V = MPEEntity*>
  void IntegrateState(U& p_force, V& entity_ptr) {
    entity_ptr->state_->pos_ += entity_ptr->state_->vel_ * dt_;
    entity_ptr->state_->vel_ *= (1.0 - damping_);
    if (p_force.sum() != 0.0) {
      entity_ptr->state_->vel_ +=
          (dt_ * p_force / entity_ptr->mass_).reshaped(dim_p_, 1);
    }
    if (entity_ptr->max_speed_ > 0) {
      double speed = sqrt(entity_ptr->state_->vel_.pow(2.0).sum());
      if (speed >= entity_ptr->max_speed_) {
        entity_ptr->state_->vel_ *= entity_ptr->max_speed_ / speed;
      }
    }
  };

  void UpdateAgentState(mt19937* gen) {
    for (const auto& agent : agents_) {
      if (agent->silent_) {
        agent->state_->c_.setZero();
      } else {
        if (agent->c_noise_) {
          agent->state_->RandomCommNoise(gen, agent->c_noise_);
        }
      }
    }
  }

 private:
  template <typename U>
  void EnsureSize(vector<U>& object, int size, bool include_landmarks = true) {
    int fill_size = size - object.size();
    if (fill_size > 0) {
      if (include_landmarks && object.size() == 2) {
        U agent_fill_value = object[0];
        U landmark_fill_value = object[1];
        int agent_fill_size = num_agents_ - 1;
        int landmark_fill_size = num_landmarks_ - 1;

        for (int i = 0; i < agent_fill_size; i++) {
          object.insert(object.begin(), agent_fill_value);
        };
        for (int i = 0; i < landmark_fill_size; i++) {
          object.push_back(landmark_fill_value);
        };
      } else if (include_landmarks && object.size() == 1) {
        U agent_fill_value = object[0];
        U landmark_fill_value = object[0];
        int agent_fill_size = num_agents_ - 1;
        int landmark_fill_size = num_landmarks_;

        for (int i = 0; i < agent_fill_size; i++) {
          object.insert(object.begin(), agent_fill_value);
        };
        for (int i = 0; i < landmark_fill_size; i++) {
          object.push_back(landmark_fill_value);
        };
      } else if (include_landmarks) {
        U agent_fill_value = object[0];
        U landmark_fill_value = object[object.size() - 1];
        int agent_fill_size = fill_size - num_landmarks_ - 1;
        int landmark_fill_size = fill_size - num_agents_ - 1;

        for (int i = 0; i < agent_fill_size; i++) {
          object.insert(object.begin(), agent_fill_value);
        };
        for (int i = 0; i < landmark_fill_size; i++) {
          object.push_back(landmark_fill_value);
        };
      } else {
        U agent_fill_value = object[object.size() - 1];
        int agent_fill_size = fill_size;

        for (int i = 0; i < agent_fill_size; i++) {
          object.push_back(agent_fill_value);
        };
      }
    }
  };
};
}  // namespace mpe

#endif

scenario.h


#ifndef ENVPOOL_MPE_SCENARIO_H_
#define ENVPOOL_MPE_SCENARIO_H_

#include "envpool/core/env.h"
#include "envpool/mpe/core.h"
#include "envpool/mpe/default_params.h"

namespace mpe {
class MPEScenario {
 public:
  unique_ptr<MPEWorld> world_;

 public:
  MPEScenario() {}

  template <typename Spec>
  void MakeWorld(const Spec& spec) {}
  virtual void ResetWorld(mt19937* gen) = 0;
  virtual vector<double> Reward() = 0;
};
}  // namespace mpe
#endif

default_params.h

#ifndef ENVPOOL_MPE_DEFAULT_PARAMS_H_
#define ENVPOOL_MPE_DEFAULT_PARAMS_H_

#include <string>
#include <tuple>

using namespace std;

namespace mpe {
// Action types
string DISCRETE_ACT = "Discrete";
string CONTINUOUS_ACT = "Continuous";

// Environment
int MAX_STEPS = 25;
double DAMPING = 0.25;
double CONTACT_FORCE = 100.0;
double CONTACT_MARGIN = 1000.0;
double DT = 0.1;

// Colours
tuple<uint_fast16_t, uint_fast16_t, uint_fast16_t> AGENT_COLOR = {115, 243,
                                                                  115};
tuple<uint_fast16_t, uint_fast16_t, uint_fast16_t> ADVERSARY_COLOR = {243, 115,
                                                                      115};
tuple<uint_fast16_t, uint_fast16_t, uint_fast16_t> OBS_COLOR = {64, 64, 64};
}  // namespace mpe

#endif

simple_env.h


#ifndef ENVPOOL_MPE_SIMPLE_ENV_H_
#define ENVPOOL_MPE_SIMPLE_ENV_H_

#include <Eigen/Core>
#include <array>
#include <limits>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include "envpool/core/async_envpool.h"
#include "envpool/core/env.h"
#include "envpool/mpe/core.h"
#include "envpool/mpe/default_params.h"
#include "envpool/mpe/scenario.h"

using namespace std;

namespace mpe {

class SimpleEnvFns {
 public:
  static decltype(auto) DefaultConfig() {
    return MakeDict(
        "num_agents"_.Bind(1), "num_landmarks"_.Bind(1),
        "color"_.Bind(
            vector<tuple<uint_fast16_t, uint_fast16_t, uint_fast16_t>>(
                {AGENT_COLOR, OBS_COLOR})),
        "dim_c"_.Bind(0), "dim_p"_.Bind(2), "dt"_.Bind(DT),
        "size"_.Bind(vector<double>({0.15, 0.2})),
        "moveable"_.Bind(vector<bool>({true, false})),
        "silent"_.Bind(vector<int>({1})),
        "collide"_.Bind(vector<bool>({false, false})),
        "mass"_.Bind(vector<double>({1.0, 1.0})),
        "accel"_.Bind(vector<double>({5.0})),
        "max_speed"_.Bind(vector<double>({-1.0, 0.0})),
        "u_noise"_.Bind(vector<double>({0.0})),
        "c_noise"_.Bind(vector<double>({0.0})), "damping"_.Bind(DAMPING),
        "contact_force"_.Bind(CONTACT_FORCE),
        "contact_margin"_.Bind(CONTACT_MARGIN));
  }

  template <typename Config>
  static decltype(auto) StateSpec(const Config& conf) {
    double inf = numeric_limits<double>::infinity();
    // x vel, y vel, x & y rel_pos for each other entity
    int obs_dim =
        conf["dim_p"_] * (1 + conf["num_landmarks"_] + conf["num_agents"_] - 1);
    vector<double> obs_lower(obs_dim, -inf), obs_upper(obs_dim, inf);
    return MakeDict(
        "obs"_.Bind(Spec<double>({-1, obs_dim}, {obs_lower, obs_upper})),
        "info:players.trunc"_.Bind(Spec<bool>({-1})),
        "info:players.term"_.Bind(Spec<bool>({-1})),
        "info:terminated"_.Bind(Spec<bool>({})));
  }
};

class SimpleContinuousEnvFns : public SimpleEnvFns {
 public:
  template <typename Config>
  static decltype(auto) ActionSpec(const Config& conf) {
    return MakeDict("action"_.Bind(Spec<double>(
        {-1, 2 * conf["dim_p"_]}, {0.0, 1.0})));
  }
};

class SimpleDiscreteEnvFns : public SimpleEnvFns {
 public:
  template <typename Config>
  static decltype(auto) ActionSpec(const Config& conf) {
    return MakeDict(
        "action"_.Bind(Spec<int>({-1}, {0, 2 * conf["dim_p"_]})));
  }
};

using SimpleContinuousEnvSpec = EnvSpec<SimpleContinuousEnvFns>;
using SimpleDiscreteEnvSpec = EnvSpec<SimpleDiscreteEnvFns>;

class SimpleMPEScenario : virtual public MPEScenario {
 public:
  SimpleMPEScenario() {}
  template <typename Spec>
  void MakeWorld(const Spec& spec) {
    vector<bool> silent = vector<bool>(
        spec.config["num_agents"_], true);
    vector<bool> moveable = vector<bool>(spec.config["num_agents"_], true);
    vector<bool> collide = vector<bool>(
        spec.config["num_agents"_] + spec.config["num_landmarks"_], false);
    for (int i = 0; i < spec.config["num_landmarks"_]; ++i) {
      moveable.push_back(false);
    }
    world_ = make_unique<MPEWorld>(
        spec.config["num_agents"_], spec.config["num_landmarks"_],
        spec.config["dim_p"_], spec.config["dim_c"_], spec.config["dt"_],
        spec.config["damping"_], spec.config["contact_force"_],
        spec.config["contact_margin"_], spec.config["u_noise"_],
        spec.config["c_noise"_], silent, spec.config["size"_],
        spec.config["mass"_], moveable, collide, spec.config["accel"_],
        spec.config["color"_], spec.config["max_speed"_]);
  };

  void ResetWorld(mt19937* gen) override { world_->ResetWorld(gen); }
  vector<double> Reward() override {
    vector<double> reward;
    for (auto& agent : world_->agents_) {
      double agent_reward = 0;
      for (auto& landmark : world_->landmarks_) {
        agent_reward +=
            sqrt((agent->state_->pos_ - landmark->state_->pos_).pow(2.0).sum());
      }
      reward.push_back(-agent_reward);
    }
    return reward;
  }

  vector<vector<double>> GetObs() {
    vector<vector<double>> all_agent_obs;
    fprintf(stderr, "getting obs\n");
    for (auto& agent : world_->agents_) {
      vector<double> agent_obs;
      for (auto& vel : agent->state_->vel_) {
        agent_obs.push_back(vel);
      }
      auto agent_pos = agent->state_->pos_;
      for (int entity_idx = 0; entity_idx < world_->num_entities_;
           ++entity_idx) {
        MPEEntity* entity_ptr;
        if (entity_idx < world_->num_agents_) {
          entity_ptr = world_->agents_[entity_idx];
        } else {
          entity_ptr =
              world_->landmarks_[entity_idx - world_->num_agents_];
        }
        auto diff = entity_ptr->state_->pos_ - agent_pos;
        for (auto d : diff.reshaped()) {
          agent_obs.push_back(d);
        }
      }
      all_agent_obs.push_back(agent_obs);
    }
    return all_agent_obs;
  }
};

class SimpleMPEEnv {
 public:
  bool continuous_;
  SimpleMPEScenario* scenario_;
  int max_episode_steps_;
  int elapsed_step_;

 public:
  SimpleMPEEnv(bool continuous = false, int max_steps = 25)
      : continuous_(continuous), max_episode_steps_(max_steps) {}

  template <typename Spec>
  void InitScenario(const Spec& spec) {
    scenario_ = new SimpleMPEScenario();
    scenario_->MakeWorld(spec);
  };
};

class SimpleContinuousEnv : public Env<SimpleContinuousEnvSpec>,
                            public SimpleMPEEnv {
 public:
  SimpleContinuousEnv(const Spec& spec, int env_id)
      : Env<SimpleContinuousEnvSpec>(spec, env_id),
        SimpleMPEEnv(true, spec.config["max_episode_steps"_]) {
    max_num_players_ = spec.config["num_agents"_];
    max_episode_steps_ = spec.config["max_episode_steps"_];
    InitScenario(spec);
  }

 public:
  bool IsDone() override { return elapsed_step_ >= max_episode_steps_; };
  void Reset() override {
    scenario_->ResetWorld(&gen_);
    elapsed_step_ = 0;
    WriteState();
  }
  void Step(const Action& action) override {
    ++elapsed_step_;
    scenario_->world_->DecodeContinuousActions(action);
    scenario_->world_->StepWorld(&gen_);
    WriteState();
  }

 protected:
  void WriteState() {
    State state = Allocate(max_num_players_);
    vector<vector<double>> obs = scenario_->GetObs();
    vector<double> rewards = scenario_->Reward();
    bool trunc = elapsed_step_ >= max_episode_steps_;

    for (int i = 0; i < max_num_players_; i++) {
      for (int j = 0; j < int(obs[i].size()); j++) {
        state["obs"_](i, j) = obs[i][j];
      };

      state["info:players.trunc"_](i) = trunc;
      state["info:players.term"_](i) = false;
      state["reward"_](i) = rewards[i];
    };
    state["trunc"_] = trunc;
    state["info:terminated"_] = false;
  }
};

class SimpleDiscreteEnv : public Env<SimpleDiscreteEnvSpec>,
                          public SimpleMPEEnv {
 public:
  SimpleDiscreteEnv(const Spec& spec, int env_id)
      : Env<SimpleDiscreteEnvSpec>(spec, env_id),
        SimpleMPEEnv(false, spec.config["max_episode_steps"_]) {
    max_num_players_ = spec.config["num_agents"_];
    max_episode_steps_ = spec.config["max_episode_steps"_];
    InitScenario(spec);
  }

 public:
  bool IsDone() override { return elapsed_step_ >= max_episode_steps_; };
  void Reset() override {
    scenario_->ResetWorld(&gen_);
    elapsed_step_ = 0;
    WriteState();
  }
  void Step(const Action& action) override {
    ++elapsed_step_;
    scenario_->world_->DecodeDiscreteActions(action);
    scenario_->world_->StepWorld(&gen_);
    WriteState();
  }

 protected:
  void WriteState() {
    State state = Allocate(max_num_players_);
    // vector<vector<double>> obs = scenario_->world_->GetObs();
    vector<vector<double>> obs = scenario_->GetObs();
    vector<double> rewards = scenario_->Reward();
    bool trunc = elapsed_step_ >= max_episode_steps_;

    for (int i = 0; i < max_num_players_; i++) {
      for (int j = 0; j < int(obs[i].size()); j++) {
        state["obs"_](i, j) = obs[i][j];
      };

      state["info:players.trunc"_](i) = trunc;
      state["info:players.term"_](i) = false;
      state["reward"_](i) = rewards[i];
    };
    state["trunc"_] = trunc;
    state["info:terminated"_] = false;
  }
};

using SimpleContinuousEnvPool = AsyncEnvPool<SimpleContinuousEnv>;
using SimpleDiscreteEnvPool = AsyncEnvPool<SimpleDiscreteEnv>;

}  // namespace mpe
#endif

simple_spread.h

#ifndef ENVPOOL_MPE_SIMPLE_SPREAD_H_
#define ENVPOOL_MPE_SIMPLE_SPREAD_H_

#include <limits>
#include <random>
#include <sstream>

#include "envpool/core/async_envpool.h"
#include "envpool/core/env.h"
#include "envpool/mpe/core.h"
#include "envpool/mpe/default_params.h"
#include "envpool/mpe/scenario.h"
#include "envpool/mpe/simple_env.h"

using namespace std;

namespace mpe {

class SimpleSpreadContinuousEnvFns : public SimpleContinuousEnvFns {
 public:
  static decltype(auto) DefaultConfig() {
    return MakeDict(
        "num_agents"_.Bind(3), "num_landmarks"_.Bind(3),
        "local_ratio"_.Bind(0.5),
        "color"_.Bind(
            vector<tuple<uint_fast16_t, uint_fast16_t, uint_fast16_t>>(
                {AGENT_COLOR, OBS_COLOR})),
        "dim_c"_.Bind(2), "dim_p"_.Bind(2), "dt"_.Bind(DT),
        "size"_.Bind(vector<double>({0.15, 0.2})),
        "moveable"_.Bind(vector<bool>({true, false})),
        "silent"_.Bind(vector<int>({1})),
        "collide"_.Bind(vector<bool>({true, false})),
        "mass"_.Bind(vector<double>({1.0, 1.0})),
        "accel"_.Bind(vector<double>({5.0})),
        "max_speed"_.Bind(vector<double>({-1.0, 0.0})),
        "u_noise"_.Bind(vector<double>({0.0})),
        "c_noise"_.Bind(vector<double>({0.0})), "damping"_.Bind(DAMPING),
        "contact_force"_.Bind(CONTACT_FORCE),
        "contact_margin"_.Bind(CONTACT_MARGIN));
  }

  template <typename Config>
  static decltype(auto) StateSpec(const Config& conf) {
    double inf = numeric_limits<double>::infinity();
    // x vel, y vel, x_pos, y_pos, x & y rel_pos for each landmark, x & y
    // rel_pos for each other agent
    int obs_dim = conf["dim_p"_] *
                      (2 + conf["num_landmarks"_] + (conf["num_agents"_] - 1)) +
                  conf["dim_c"_];
    return MakeDict("obs"_.Bind(Spec<double>({-1, obs_dim}, {-inf, inf})),
                    "info:players.trunc"_.Bind(Spec<bool>({-1})),
                    "info:players.term"_.Bind(Spec<bool>({-1})),
                    "info:terminated"_.Bind(Spec<bool>({})));
  }
};

class SimpleSpreadDiscreteEnvFns : public SimpleDiscreteEnvFns {
 public:
  static decltype(auto) DefaultConfig() {
    return MakeDict(
        "num_agents"_.Bind(3), "num_landmarks"_.Bind(3),
        "local_ratio"_.Bind(0.5),
        "color"_.Bind(
            vector<tuple<uint_fast16_t, uint_fast16_t, uint_fast16_t>>(
                {AGENT_COLOR, OBS_COLOR})),
        "dim_c"_.Bind(2), "dim_p"_.Bind(2), "dt"_.Bind(DT),
        "size"_.Bind(vector<double>({0.15, 0.2})),
        "moveable"_.Bind(vector<bool>({true, false})),
        "silent"_.Bind(vector<int>({1})),
        "collide"_.Bind(vector<bool>({true, false})),
        "mass"_.Bind(vector<double>({1.0, 1.0})),
        "accel"_.Bind(vector<double>({5.0})),
        "max_speed"_.Bind(vector<double>({-1.0, 0.0})),
        "u_noise"_.Bind(vector<double>({0.0})),
        "c_noise"_.Bind(vector<double>({0.0})), "damping"_.Bind(DAMPING),
        "contact_force"_.Bind(CONTACT_FORCE),
        "contact_margin"_.Bind(CONTACT_MARGIN));
  }

  template <typename Config>
  static decltype(auto) StateSpec(const Config& conf) {
    double inf = numeric_limits<double>::infinity();
    int obs_dim = conf["dim_p"_] *
                      (2 + conf["num_landmarks"_] + (conf["num_agents"_] - 1)) +
                  conf["dim_c"_];
    return MakeDict("obs"_.Bind(Spec<double>({-1, obs_dim}, {-inf, inf})),
                    "info:players.trunc"_.Bind(Spec<bool>({-1})),
                    "info:players.term"_.Bind(Spec<bool>({-1})),
                    "info:terminated"_.Bind(Spec<bool>({})));
  }
};

using SimpleSpreadContinuousEnvSpec = EnvSpec<SimpleSpreadContinuousEnvFns>;
using SimpleSpreadDiscreteEnvSpec = EnvSpec<SimpleSpreadDiscreteEnvFns>;

class SimpleSpreadScenario : virtual public SimpleMPEScenario {
 public:
  double local_ratio_;

 public:
  SimpleSpreadScenario() {}
  template <typename Spec>
  void MakeWorld(const Spec& spec) {
    vector<double> size = vector<double>(spec.config["num_agents"_], 0.15);
    vector<bool> silent = vector<bool>(spec.config["num_agents"_], true);
    vector<bool> moveable = vector<bool>(spec.config["num_agents"_], true);
    vector<bool> collide = vector<bool>(spec.config["num_agents"_], true);
    for (int i = 0; i < spec.config["num_landmarks"_]; ++i) {
      size.push_back(0.05);
      moveable.push_back(false);
      collide.push_back(false);
    }
    local_ratio_ = spec.config["local_ratio"_];
    world_ = make_unique<MPEWorld>(
        spec.config["num_agents"_], spec.config["num_landmarks"_],
        spec.config["dim_p"_], spec.config["dim_c"_], spec.config["dt"_],
        spec.config["damping"_], spec.config["contact_force"_],
        spec.config["contact_margin"_], spec.config["u_noise"_],
        spec.config["c_noise"_], silent, size, spec.config["mass"_], moveable,
        collide, spec.config["accel"_], spec.config["color"_],
        spec.config["max_speed"_]);
  };

  void ResetWorld(mt19937* gen) override { this->world_->ResetWorld(gen); }

  vector<vector<double>> GetObs() {
    vector<vector<double>> all_agent_obs;
    int curr_agent = 0;
    for (const auto& agent : this->world_->agents_) {
      vector<double> agent_obs;
      for (const auto& vel : agent->state_->vel_.reshaped()) {
        agent_obs.push_back(vel);
      }
      for (int p_pos = 0; p_pos < this->world_->dim_p_; ++p_pos) {
        agent_obs.push_back(agent->state_->pos_.reshaped()[p_pos]);
      }
      for (const auto& landmark : this->world_->landmarks_) {
        Eigen::MatrixXd disp = landmark->state_->pos_ - agent->state_->pos_;
        for (const auto& d : disp.reshaped()) {
          agent_obs.push_back(d);
        }
      }

      // add rel pos to other agents
      int other_idx = 0;
      for (const auto& other : this->world_->agents_) {
        if (curr_agent == other_idx) {
          ++other_idx;
          continue;
        }
        Eigen::MatrixXd disp = other->state_->pos_ - agent->state_->pos_;
        for (const auto& d : disp.reshaped()) {
          agent_obs.push_back(d);
        }
        ++other_idx;
      }

      // add to all obs
      all_agent_obs.push_back(agent_obs);
      ++curr_agent;
    }
    return all_agent_obs;
  }

  vector<double> Reward() override {
    vector<double> reward;
    double global_reward = 0;

    for (const auto& landmark : this->world_->landmarks_) {
      // distance from this landmark to the closest agent (reset per landmark)
      double min_dist = numeric_limits<double>::infinity();
      for (const auto& agent : this->world_->agents_) {
        Eigen::ArrayXd delta_pos = landmark->state_->pos_ - agent->state_->pos_;
        double dist = sqrt(delta_pos.pow(2.0).sum());
        if (dist < min_dist) {
          min_dist = dist;
        }
      }
      if (min_dist < numeric_limits<double>::infinity()) {
        global_reward -= min_dist;
      }
    }

    for (const auto& agent : this->world_->agents_) {
      double agent_reward = 0;
      if (agent->collide_) {
        for (const auto& other : this->world_->agents_) {
          agent_reward -= local_ratio_ * (IsCollision(agent, other));
        }
      }
      agent_reward += global_reward * (1.0 - local_ratio_);
      reward.push_back(agent_reward);
    }
    return reward;
  }

  template <typename U>
  bool IsCollision(U& entity0, U& entity1) {
    if (entity0 != entity1 && entity0->collide_ && entity1->collide_) {
      Eigen::ArrayXd delta_pos = entity0->state_->pos_ - entity1->state_->pos_;
      double dist = sqrt(delta_pos.pow(2.0).sum());
      double dist_min = entity0->size_ + entity1->size_;
      return dist < dist_min;
    }
    return false;
  }
};

class SimpleSpreadEnv : virtual public SimpleMPEEnv {
 public:
  double local_ratio_;
  unique_ptr<SimpleSpreadScenario> scenario_;

 public:
  SimpleSpreadEnv(bool continuous = false, int max_steps = 25,
                  double local_ratio = 0.5)
      : SimpleMPEEnv(continuous, max_steps), local_ratio_(local_ratio) {}

  template <typename Spec>
  void InitScenario(const Spec& spec) {
    scenario_ = make_unique<SimpleSpreadScenario>();
    this->scenario_->MakeWorld(spec);
  }
};

class SimpleSpreadContinuousEnv : public Env<SimpleSpreadContinuousEnvSpec>,
                                  public SimpleSpreadEnv {
 public:
  SimpleSpreadContinuousEnv(const Spec& spec, int env_id)
      : Env<SimpleSpreadContinuousEnvSpec>(spec, env_id),
        SimpleSpreadEnv(true, spec.config["max_episode_steps"_],
                        spec.config["local_ratio"_]) {
    max_num_players_ = spec.config["num_agents"_];
    max_episode_steps_ = spec.config["max_episode_steps"_];
    SimpleSpreadEnv::InitScenario(spec);
  }
  bool IsDone() override { return elapsed_step_ >= max_episode_steps_; };
  void Reset() override {
    this->scenario_->ResetWorld(&gen_);
    elapsed_step_ = 0;
    WriteState();
  }
  void Step(const Action& action) override {
    ++elapsed_step_;
    this->scenario_->world_->DecodeContinuousActions(action);
    this->scenario_->world_->StepWorld(&gen_);
    WriteState();
  }

 protected:
  void WriteState() {
    State state = Allocate(max_num_players_);
    vector<vector<double>> obs = this->scenario_->GetObs();
    vector<double> rewards = this->scenario_->Reward();
    bool trunc = elapsed_step_ >= max_episode_steps_;

    for (int i = 0; i < max_num_players_; i++) {
      for (int j = 0; j < int(obs[i].size()); j++) {
        state["obs"_](i, j) = obs[i][j];
      }

      state["info:players.trunc"_](i) = trunc;
      state["info:players.term"_](i) = false;
      state["reward"_](i) = rewards[i];
    };
    state["trunc"_] = trunc;
    state["info:terminated"_] = false;
  }
};

class SimpleSpreadDiscreteEnv : public Env<SimpleSpreadDiscreteEnvSpec>,
                                public SimpleSpreadEnv {
 public:
  SimpleSpreadDiscreteEnv(const Spec& spec, int env_id)
      : Env<SimpleSpreadDiscreteEnvSpec>(spec, env_id),
        SimpleSpreadEnv(false, spec.config["max_episode_steps"_],
                        spec.config["local_ratio"_]) {
    max_num_players_ = spec.config["num_agents"_];
    max_episode_steps_ = spec.config["max_episode_steps"_];
    InitScenario(spec);
  }
  bool IsDone() override { return elapsed_step_ >= max_episode_steps_; };
  void Reset() override {
    this->scenario_->ResetWorld(&gen_);
    elapsed_step_ = 0;
    WriteState();
  }
  void Step(const Action& action) override {
    ++elapsed_step_;
    stringstream msg;
    msg << "env: " << env_id_ << endl;
    cout << msg.str();
    this->scenario_->world_->DecodeDiscreteActions(action);
    this->scenario_->world_->StepWorld(&gen_);
    WriteState();
  }

 protected:
  void WriteState() {
    State state = Allocate(max_num_players_);
    vector<vector<double>> obs = this->scenario_->GetObs();
    vector<double> rewards = this->scenario_->Reward();
    bool trunc = elapsed_step_ >= max_episode_steps_;

    for (int i = 0; i < max_num_players_; i++) {
      for (int j = 0; j < int(obs[i].size()); j++) {
        state["obs"_](i, j) = obs[i][j];
      };

      state["info:players.trunc"_](i) = trunc;
      state["info:players.term"_](i) = false;
      state["reward"_](i) = rewards[i];
    };
    state["trunc"_] = trunc;
    state["info:terminated"_] = false;
  }
};

using SimpleSpreadContinuousEnvPool = AsyncEnvPool<SimpleSpreadContinuousEnv>;
using SimpleSpreadDiscreteEnvPool = AsyncEnvPool<SimpleSpreadDiscreteEnv>;

}  // namespace mpe
#endif

mpe.cc

#include "envpool/core/py_envpool.h"
#include "envpool/mpe/simple_env.h"
#include "envpool/mpe/simple_spread.h"

using SimpleDiscreteEnvSpec = PyEnvSpec<mpe::SimpleDiscreteEnvSpec>;
using SimpleDiscreteEnvPool = PyEnvPool<mpe::SimpleDiscreteEnvPool>;

using SimpleContinuousEnvSpec = PyEnvSpec<mpe::SimpleContinuousEnvSpec>;
using SimpleContinuousEnvPool = PyEnvPool<mpe::SimpleContinuousEnvPool>;

using SimpleSpreadContinuousEnvSpec =
    PyEnvSpec<mpe::SimpleSpreadContinuousEnvSpec>;
using SimpleSpreadContinuousEnvPool =
    PyEnvPool<mpe::SimpleSpreadContinuousEnvPool>;

using SimpleSpreadDiscreteEnvSpec = PyEnvSpec<mpe::SimpleSpreadDiscreteEnvSpec>;
using SimpleSpreadDiscreteEnvPool = PyEnvPool<mpe::SimpleSpreadDiscreteEnvPool>;

PYBIND11_MODULE(mpe_envpool, m) {
  REGISTER(m, SimpleDiscreteEnvSpec, SimpleDiscreteEnvPool)
  REGISTER(m, SimpleContinuousEnvSpec, SimpleContinuousEnvPool)
  REGISTER(m, SimpleSpreadDiscreteEnvSpec, SimpleSpreadDiscreteEnvPool)
  REGISTER(m, SimpleSpreadContinuousEnvSpec, SimpleSpreadContinuousEnvPool)
}

BUILD

load("@pip_requirements//:requirements.bzl", "requirement")
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")

package(default_visibility = ["//visibility:public"])

cc_library(
    name = "mpe_env",
    hdrs = [
        "default_params.h",
        "core.h",
        "scenario.h",
        "simple_env.h",
        "simple_spread.h",
    ],
    deps = [
        "//envpool/core:async_envpool",
        "@eigen",
    ],
)

pybind_extension(
    name = "mpe_envpool",
    srcs = ["mpe.cc"],
    deps = [
        ":mpe_env",
        "//envpool/core:py_envpool",
    ],
)

py_library(
    name = "mpe",
    srcs = ["__init__.py"],
    data = [":mpe_envpool.so"],
    deps = ["//envpool/python:api"],
)

py_library(
    name = "mpe_registration",
    srcs = ["registration.py"],
    deps = [
        "//envpool:registration",
    ],
)

registration.py
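
Roughly, my registration.py follows the pattern of the registration files in the EnvPool repo/docs; a minimal sketch is below (the exact keyword arguments and the max_episode_steps value are a simplification of my actual file, and the Simple* environments are registered analogously):

from envpool.registration import register

register(
    task_id="SimpleSpreadDiscrete-v0",
    import_path="envpool.mpe",
    spec_cls="SimpleSpreadDiscreteEnvSpec",
    dm_cls="SimpleSpreadDiscreteDMEnvPool",
    gym_cls="SimpleSpreadDiscreteGymEnvPool",
    gymnasium_cls="SimpleSpreadDiscreteGymnasiumEnvPool",
    max_episode_steps=25,
)

register(
    task_id="SimpleSpreadContinuous-v0",
    import_path="envpool.mpe",
    spec_cls="SimpleSpreadContinuousEnvSpec",
    dm_cls="SimpleSpreadContinuousDMEnvPool",
    gym_cls="SimpleSpreadContinuousGymEnvPool",
    gymnasium_cls="SimpleSpreadContinuousGymnasiumEnvPool",
    max_episode_steps=25,
)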

load("@pip_requirements//:requirements.bzl", "requirement")
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")

package(default_visibility = ["//visibility:public"])

cc_library(
    name = "mpe_env",
    hdrs = [
        "default_params.h",
        "core.h",
        "scenario.h",
        "simple_env.h",
        # "simple_discrete_env.h",
        # "simple_continuous_env.h",
        "simple_spread.h",
    ],
    deps = [
        "//envpool/core:async_envpool",
        "@eigen",
    ],
)

pybind_extension(
    name = "mpe_envpool",
    srcs = ["mpe.cc"],
    deps = [
        ":mpe_env",
        "//envpool/core:py_envpool",
    ],
)

py_library(
    name = "mpe",
    srcs = ["__init__.py"],
    data = [":mpe_envpool.so"],
    deps = ["//envpool/python:api"],
)

py_library(
    name = "mpe_registration",
    srcs = ["registration.py"],
    deps = [
        "//envpool:registration",
    ],
)

__init__.py


from envpool.python.api import py_env

from .mpe_envpool import (
    _SimpleDiscreteEnvPool,
    _SimpleDiscreteEnvSpec,
    _SimpleContinuousEnvPool,
    _SimpleContinuousEnvSpec,
    _SimpleSpreadDiscreteEnvPool,
    _SimpleSpreadDiscreteEnvSpec,
    _SimpleSpreadContinuousEnvPool,
    _SimpleSpreadContinuousEnvSpec,
)

(
    SimpleDiscreteEnvSpec,
    SimpleDiscreteDMEnvPool,
    SimpleDiscreteGymEnvPool,
    SimpleDiscreteGymnasiumEnvPool,
) = py_env(_SimpleDiscreteEnvSpec, _SimpleDiscreteEnvPool)

(
    SimpleContinuousEnvSpec,
    SimpleContinuousDMEnvPool,
    SimpleContinuousGymEnvPool,
    SimpleContinuousGymnasiumEnvPool,
) = py_env(_SimpleContinuousEnvSpec, _SimpleContinuousEnvPool)

(
    SimpleSpreadDiscreteEnvSpec,
    SimpleSpreadDiscreteDMEnvPool,
    SimpleSpreadDiscreteGymEnvPool,
    SimpleSpreadDiscreteGymnasiumEnvPool,
) = py_env(_SimpleSpreadDiscreteEnvSpec, _SimpleSpreadDiscreteEnvPool)

(
    SimpleSpreadContinuousEnvSpec,
    SimpleSpreadContinuousDMEnvPool,
    SimpleSpreadContinuousGymEnvPool,
    SimpleSpreadContinuousGymnasiumEnvPool,
) = py_env(_SimpleSpreadContinuousEnvSpec, _SimpleSpreadContinuousEnvPool)

__all__ = [
    "SimpleDiscreteEnvSpec",
    "SimpleDiscreteDMEnvPool",
    "SimpleDiscreteGymEnvPool",
    "SimpleDiscreteGymnasiumEnvPool",
    "SimpleContinuousEnvSpec",
    "SimpleContinuousDMEnvPool",
    "SimpleContinuousGymEnvPool",
    "SimpleContinuousGymnasiumEnvPool",
    "SimpleSpreadDiscreteEnvSpec",
    "SimpleSpreadDiscreteDMEnvPool",
    "SimpleSpreadDiscreteGymEnvPool",
    "SimpleSpreadDiscreteGymnasiumEnvPool",
    "SimpleSpreadContinuousEnvSpec",
    "SimpleSpreadContinuousDMEnvPool",
    "SimpleSpreadContinuousGymEnvPool",
    "SimpleSpreadContinuousGymnasiumEnvPool",
]

Reproduction script

Below is modified from the test script for LunarLander

import os
import uuid

from dataclasses import dataclass, field
from pathlib import Path

import numpy as np

import envpool
import jax
import tyro

from rich.pretty import pprint

# Fix weird OOM https://github.com/google/jax/discussions/6332#discussioncomment-1279991
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.6"
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
# Fix CUDNN non-determinisim; https://github.com/google/jax/issues/4823#issuecomment-952835771
os.environ["TF_XLA_FLAGS"] = "--xla_gpu_autotune_level=2 --xla_gpu_deterministic_reductions"
os.environ["TF_CUDNN DETERMINISTIC"] = "1"

@dataclass
class Args:
    exp_name: str = Path(__file__).stem
    "the name of this experiment"
    seed: int = 1
    "seed of the experiment"
    track: bool = False
    # "if toggled, this experiment will be tracked with Weights and Biases"
    # wandb_project_name: str = "cleanRL"
    # "the wandb's project name"
    # wandb_entity: str = None
    # "the entity (team) of wandb's project"
    # capture_video: bool = False
    # "whether to capture videos of the agent performances (check out `videos` folder)"
    save_model: bool = False
    "whether to save model into the `runs/{run_name}` folder"
    upload_model: bool = False
    "whether to upload the saved model to huggingface"
    hf_entity: str = ""
    "the user or org name of the model repository from the Hugging Face Hub"
    log_frequency: int = 10
    "the logging frequency of the model performance (in terms of `updates`)"

    # Algorithm specific arguments
    # env_id: str = "Breakout-v5"
    env_id: str = "SimpleSpreadDiscrete-v0"
    "the id of the environment"
    total_timesteps: int = 50000000
    "total timesteps of the experiments"
    learning_rate: float = 2.5e-4
    "the learning rate of the optimizer"
    local_num_envs: int = 4
    "the number of parallel game environments"
    num_actor_threads: int = 2
    "the number of actor threads to use"
    num_steps: int = 128
    "the number of steps to run in each environment per policy rollout"
    anneal_lr: bool = True
    "Toggle learning rate annealing for policy and value networks"
    gamma: float = 0.99
    "the discount factor gamma"
    gae_lambda: float = 0.95
    "the lambda for the general advantage estimation"
    num_minibatches: int = 4
    "the number of mini-batches"
    gradient_accumulation_steps: int = 1
    "the number of gradient accumulation steps before performing an optimization step"
    update_epochs: int = 4
    "the K epochs to update the policy"
    norm_adv: bool = True
    "Toggles advantages normalization"
    clip_coef: float = 0.1
    "the surrogate clipping coefficient"
    ent_coef: float = 0.01
    "coefficient of the entropy"
    vf_coef: float = 0.5
    "coefficient of the value function"
    max_grad_norm: float = 0.5
    "the maximum norm for the gradient clipping"
    channels: list[int] = field(default_factory=lambda: [16, 32, 32])
    "the channels of the CNN"
    hiddens: list[int] = field(default_factory=lambda: [256])
    "the hiddens size of the MLP"

    actor_device_ids: list[int] = field(default_factory=lambda: [0])
    "the device ids that actor workers will use"
    learner_device_ids: list[int] = field(default_factory=lambda: [0])
    "the device ids that learner workers will use"
    distributed: bool = False
    "whether to use `jax.distirbuted`"
    concurrency: bool = False
    "whether to run the actor and learner concurrently"

    # runtime arguments to be filled in
    local_batch_size: int = 0
    local_minibatch_size: int = 0
    num_updates: int = 0
    world_size: int = 0
    local_rank: int = 0
    num_envs: int = 0
    batch_size: int = 0
    minibatch_size: int = 0
    global_learner_decices: list[str] | None = None
    actor_devices: list[str] | None = None
    learner_devices: list[str] | None = None

if __name__ == "__main__":
    args = tyro.cli(Args)
    args.local_batch_size = int(args.local_num_envs * args.num_steps * args.num_actor_threads * len(args.actor_device_ids))
    args.local_minibatch_size = int(args.local_batch_size // args.num_minibatches)
    assert args.local_num_envs % len(args.learner_device_ids) == 0, "local_num_envs must be divisible by len(learner_device_ids)"
    assert (
        int(args.local_num_envs / len(args.learner_device_ids)) * args.num_actor_threads % args.num_minibatches == 0
    ), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches"
    if args.distributed:
        jax.distributed.initialize(
            local_device_ids=range(len(args.learner_device_ids) + len(args.actor_device_ids)),
        )
        print(list(range(len(args.learner_device_ids) + len(args.actor_device_ids))))

    args.world_size = jax.process_count()
    args.local_rank = jax.process_index()
    args.num_envs = args.local_num_envs * args.world_size * args.num_actor_threads * len(args.actor_device_ids)
    args.batch_size = args.local_batch_size * args.world_size
    args.minibatch_size = args.local_minibatch_size * args.world_size
    args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
    local_devices = jax.local_devices()
    global_devices = jax.devices()
    learner_devices = [local_devices[d_id] for d_id in args.learner_device_ids]
    actor_devices = [local_devices[d_id] for d_id in args.actor_device_ids]
    global_learner_decices = [
        global_devices[d_id + process_index * len(local_devices)]
        for process_index in range(args.world_size)
        for d_id in args.learner_device_ids
    ]
    print("global_learner_decices", global_learner_decices)
    args.global_learner_decices = [str(item) for item in global_learner_decices]
    args.actor_devices = [str(item) for item in actor_devices]
    args.learner_devices = [str(item) for item in learner_devices]
    pprint(args)

    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{uuid.uuid4()}"

    num_envs = 4
    num_players = 3

    envs = envpool.make(
        args.env_id,
        env_type="gymnasium",
        num_envs=num_envs,
        max_num_players=num_players,
        num_agents=num_players,
        num_landmarks=3,
        seed=args.seed,
    )
    act_space = envs.action_space
    obs0, info = envs.reset()
    for _ in range(5000):
        if (_ + 1) % 250 == 0:
            print(f"iter {_}")
        # action = np.array([act_space.sample() for _ in range(args.local_num_envs)])
        action = np.array([act_space.sample() for _ in range(num_envs * num_players)])
        if (_ + 1) % 250 == 0:
            print(f"sending action {action} to environment")
        # obs0, rew0, terminated, truncated, info0 = envs.step(action[:, None], env_id=np.arange(1))
        obs0, rew0, terminated, truncated, info0 = envs.step(action.reshape(-1), env_id=np.arange(num_envs))
        if (_ + 1) % 250 == 0:
            print(f"reward {rew0.reshape(num_envs, -1).sum(-1)} from environment")
            print()
    envs.close()

Expected behavior

Each C++ Step call should receive the corresponding environment's own contiguous block of max_num_players actions from the flat action array sent through the Python API, i.e. the actions for the i-th environment stepped in a batch should start at index i * max_num_players, not at index i.


System info

Describe the characteristic of your environment:

import envpool, numpy, sys
print(envpool.__version__, numpy.__version__, sys.version, sys.platform)

0.8.4 1.26.4 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:36:39) [GCC 12.3.0] linux


Checklist