google-deepmind / hanabi-learning-environment

hanabi_learning_environment is a research platform for Hanabi experiments.
Apache License 2.0

Defining Non-canonical (Float) Observations #44

Status: Open. 0xJchen opened this issue 2 years ago.

0xJchen commented 2 years ago

Hi, thanks for the great project! I noticed (from a comment in the Python encoder) that if we want to use float elements in the observation vector, we should use a custom object. The C++ side looks like this:

```cpp
// C++ side of the Python bindings: serializes the encoded observation as a
// comma-separated string that Python receives through cffi.
char* EncodeObservation(pyhanabi_observation_encoder_t* encoder,
                        pyhanabi_observation_t* observation) {
  REQUIRE(encoder != nullptr);
  REQUIRE(encoder->encoder != nullptr);
  REQUIRE(observation != nullptr);
  REQUIRE(observation->observation != nullptr);
  auto obs_enc = reinterpret_cast<hanabi_learning_env::ObservationEncoder*>(
      encoder->encoder);
  auto obs = reinterpret_cast<hanabi_learning_env::HanabiObservation*>(
      observation->observation);
  // The encoder produces integer features (std::vector<int>).
  std::vector<int> encoding = obs_enc->Encode(*obs);
  std::string obs_str = "";
  for (int i = 0; i < encoding.size(); i++) {
    // Each feature is written as "1" (nonzero) or "0".
    obs_str += (encoding[i] ? "1" : "0");
    if (i != encoding.size() - 1) {
      obs_str += ",";
    }
  }
  // The Python caller frees this copy with lib.DeleteString().
  return strdup(obs_str.c_str());
}
```
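The string format can be previewed without building the bindings. Below is a hypothetical Python mirror of the serialization loop (the name `mirror_cpp_encode` is made up for illustration and is not part of the library). It shows that the ternary maps every nonzero value, fractional or not, to "1". Note also that `encoding` is declared as `std::vector<int>`, so in the C++ code a fractional feature would already be truncated before the loop runs; emitting floats would require changing that vector's element type (and what `Encode` returns), not only the string formatting.

```python
def mirror_cpp_encode(encoding):
    """Illustrative Python copy of the serialization loop in EncodeObservation.

    Not part of hanabi_learning_environment: each element becomes "1" if it is
    nonzero and "0" otherwise, and the results are joined with commas.
    """
    return ",".join("1" if value else "0" for value in encoding)


# A fractional value such as 0.25 is nonzero, so it is written as "1".
print(mirror_cpp_encode([0, 1, 0.25, 3]))  # 0,1,1,1
```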

And the Python side looks like this:

```python
def encode(self, observation):
  """Encode the observation as a sequence of bits."""
  c_encoding_str = lib.EncodeObservation(self._encoder,
                                         observation.observation())
  encoding_string = encode_ffi_string(c_encoding_str)
  lib.DeleteString(c_encoding_str)
  # Canonical observations are bit strings, so it is ok to encode using a
  # string. For float or double observations, make a custom object
  encoding = [int(x) for x in encoding_string.split(",")]
  return encoding
```
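The list comprehension in `encode()` calls `int(x)` on every field, and `int()` rejects decimal text such as `"0.500000"` with a `ValueError`. A minimal sketch of a float-aware parser, assuming the C++ side has been changed to write each element as decimal text (the helper name `decode_float_encoding` is hypothetical, not a library function):

```python
def decode_float_encoding(encoding_string):
    """Parse a comma-separated string of decimal numbers into Python floats.

    Hypothetical drop-in for the int(x) comprehension in encode(); assumes the
    C++ side writes every element as decimal text (e.g. "0.500000").
    """
    if not encoding_string:
        # An empty string would otherwise reach float("") and raise.
        return []
    return [float(x) for x in encoding_string.split(",")]


try:
    int("0.500000")
except ValueError:
    print("int() rejects decimal text")  # what the current decoder would hit

print(decode_float_encoding("1.000000,0.500000,0.000000"))  # [1.0, 0.5, 0.0]
```

Inside `encode()`, this would replace the line `encoding = [int(x) for x in encoding_string.split(",")]`.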

I understand the current implementation only handles integer observation elements: they are first converted to a comma-separated string of "0"/"1" characters, passed to Python through cffi, and parsed there. For `float`, I tried replacing `obs_str += (encoding[i] ? "1" : "0");` with `obs_str += std::to_string(encoding[i]);` (assuming the contents of `encoding` are floats), but the values decoded on the Python side are not floats. Are there any examples demonstrating how to handle float observations?
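Two further details matter once floats do reach the string. First, `std::to_string` on a floating-point value formats it the way `printf`'s `%f` would (the behavior specified up to C++23), i.e. with six digits after the decimal point, so values below about 5e-7 in magnitude print as zero and others are rounded; Python's `%f` formatting follows the same rule, which makes the loss easy to preview. Second, a binary transfer avoids the text round trip entirely. The sketch below assumes, hypothetically, that the C side could expose the encoding as a contiguous C `float` array whose bytes Python obtains (for example via cffi's `ffi.buffer`); the library's `EncodeObservation` does not do this, and here `struct.pack` merely stands in for those bytes. Both helper names are made up for illustration.

```python
import struct


def float_text_roundtrip(value):
    """Mimic std::to_string on a float (printf "%f", six decimals), then parse it back."""
    return float("%f" % value)


def decode_float32_buffer(raw):
    """Decode a buffer of native-byte-order float32 values into Python floats.

    Assumes (hypothetically) that the C side hands over a contiguous C float
    array; the library's EncodeObservation returns a string instead.
    """
    if len(raw) % 4:
        raise ValueError("buffer length is not a multiple of 4 bytes")
    count = len(raw) // 4
    return list(struct.unpack("=%df" % count, raw))


values = [0.5, 1e-7, 0.1]
# Text path: 1e-7 prints as "0.000000" and is lost.
print([float_text_roundtrip(v) for v in values])  # [0.5, 0.0, 0.1]

# Binary path: struct.pack stands in for the raw bytes of a C float array.
raw = struct.pack("=%df" % len(values), *values)
print(decode_float32_buffer(raw))  # values keep float32 precision
```

The binary path keeps full float32 precision and skips building and parsing strings; the text path is simpler to wire into the existing `EncodeObservation`/`DeleteString` pair, and printing more digits (for example through `std::ostringstream` with `std::setprecision(std::numeric_limits<float>::max_digits10)`) would avoid the rounding shown above.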