entity-neural-network / entity-gym

Standard interface for entity based reinforcement learning environments.

[api] Categorical features #2

Open cswinter opened 2 years ago

cswinter commented 2 years ago

Currently, all features are assumed to be scalars. We should add support for categorical features and one-hot encoding as well.

Theomat commented 2 years ago

So I tried the following implementation:

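import dataclasses
from dataclasses import dataclass
from typing import Dict, List, Mapping, Optional, Sequence, Type, Union

import numpy as np

@dataclass
class Feature:
    name: str
    shape: List[int] = dataclasses.field(default_factory=lambda: [1])
    type: Type = dataclasses.field(default=float)

    def match(self, value: Union[float, np.ndarray], batched: bool = False) -> bool:
        if isinstance(value, float):
            return not batched and len(self.shape) == 1 and self.shape[0] == 1
        # ndarray.shape is a tuple, so convert the list before comparing.
        shape = tuple(self.shape)
        return value.shape[1:] == shape if batched else value.shape == shape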

class CategoricalFeature(Feature):
    categories: int

    def __init__(self, name: str, categories: int) -> None:
        super().__init__(name, [1], type=int)
        self.categories = categories

@dataclass(frozen=False)
class FeatureValue:
    feature: Feature
    value: Union[float, np.ndarray]

    def as_array(self) -> np.ndarray:
        if isinstance(self.value, float):
            return np.array([self.value])
        return self.value

@dataclass
class Observation:
    entities: Dict[str, List[FeatureValue]]
    """Maps each entity type to an array with the features for each observed entity of that type."""
    ids: Sequence[EntityID]
    """
    Maps each entity index to an opaque identifier used by the environment to
    identify that entity.
    """
    action_masks: Mapping[str, ActionMask]
    """Maps each action to an action mask."""
    reward: float
    done: bool
    end_of_episode_info: Optional[EpisodeStats] = None

@dataclass
class Entity:
    features: List[Feature]

But then we run into a number of issues:

  1. For the MultiSnake environment, for example, the color feature is categorical, but the number of categories depends on how the environment is configured. Because `obs_space` is a class method, it cannot give an accurate representation of the feature space. I see two solutions: make it an instance method, or treat -1 as a special value for `CategoricalFeature` that indicates the number of categories is not known.

  2. Let's look at a subset of the code above in more depth:

    
    @dataclass(frozen=False)
    class FeatureValue:
        feature: Feature
        value: Union[float, np.ndarray]

        def as_array(self) -> np.ndarray:
            if isinstance(self.value, float):
                return np.array([self.value])
            return self.value

    @dataclass
    class Observation:
        entities: Dict[str, List[FeatureValue]]


We probably do not want to create a `FeatureValue` object at each step for each `Feature`, for a number of reasons. This is why it is declared with `@dataclass(frozen=False)`: the object is mutable, so we can create just one per feature for the whole life of the environment. At the same time, this complicates `Observation`, which is now made of `FeatureValue` objects: you can no longer save an `Observation` directly, because the `FeatureValue` objects it references are mutable. So clearly this representation has issues.
One solution could be to make `FeatureValue` a type alias, `FeatureValue = Tuple[Feature, Any]`. But then we would again create a `Tuple` for each feature at each step, which is essentially the same as going back to `@dataclass(frozen=True)`.
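
To make the aliasing problem concrete, here is a minimal sketch. The class names `MutableFeatureValue` and `FrozenFeatureValue` are made up for illustration and simplified to hold a single float; they are not part of entity-gym.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=False)
class MutableFeatureValue:  # hypothetical, simplified stand-in for FeatureValue
    name: str
    value: float


@dataclass(frozen=True)
class FrozenFeatureValue:  # hypothetical immutable variant
    name: str
    value: float


# One mutable object reused across steps: every saved entry aliases it.
shared = MutableFeatureValue("x", 0.0)
history: List[MutableFeatureValue] = []
for step in range(3):
    shared.value = float(step)
    history.append(shared)  # stores a reference, not a snapshot

# Every entry reports the value written on the last step.
mutable_history = [fv.value for fv in history]

# Frozen objects: one allocation per step, but each step's value is preserved.
frozen_history = [FrozenFeatureValue("x", float(step)).value for step in range(3)]

print(mutable_history)
print(frozen_history)
```

Saving the mutable variant safely would require copying it at each step, which brings back the per-step allocation that reusing the object was meant to avoid.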

*This is more of a comment to raise potential issues to consider than an implementation proposal.*

cswinter commented 2 years ago

For environments like MultiSnake, I think it's perfectly fine to just set an upper cap on the number of colors (e.g. 32). Though, it does seem like we want to change the abstraction somewhat to allow for more configurable environments whose feature/action spaces are not fixed.
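
As a rough sketch of the two options discussed so far (a fixed cap on a class-level space versus a space that depends on the instance), something like the following could work. `MultiSnakeSketch`, `Categorical`, `MAX_COLORS`, and `instance_obs_space` are hypothetical names for illustration, not the actual entity-gym API.

```python
from dataclasses import dataclass
from typing import Dict, List

MAX_COLORS = 32  # the cap value is arbitrary; 32 is the example from the thread


@dataclass(frozen=True)
class Categorical:  # hypothetical feature spec
    name: str
    categories: int


class MultiSnakeSketch:  # hypothetical stand-in for the real environment
    def __init__(self, num_colors: int) -> None:
        if not 1 <= num_colors <= MAX_COLORS:
            raise ValueError(f"num_colors must be between 1 and {MAX_COLORS}")
        self.num_colors = num_colors

    @classmethod
    def obs_space(cls) -> Dict[str, List[Categorical]]:
        # Option 1: class-level space with a fixed cap, independent of configuration.
        return {"food": [Categorical("color", MAX_COLORS)]}

    def instance_obs_space(self) -> Dict[str, List[Categorical]]:
        # Option 2: instance-level space sized to this particular configuration.
        return {"food": [Categorical("color", self.num_colors)]}


capped = MultiSnakeSketch.obs_space()["food"][0].categories
exact = MultiSnakeSketch(4).instance_obs_space()["food"][0].categories
print(capped, exact)
```

The capped version keeps the space known before any environment is constructed, at the cost of unused one-hot columns when fewer colors are in play.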

The most obvious solution to how to represent Observations is to still just keep them as a single flat np.ndarray, with one-hot encoding for categorical features. This avoids the FeatureValue problem.

E.g., MultiSnake could look like this:

obs_space = {
    "food": Entity(
        Scalar("x"), # maybe we want to have a shorthand where we just use a `str` for scalar features?
        Scalar("y"),
        Categorical("Color", 4),
    ),
    # ...
}
observation = {
    "food": array([
        [10.0, 6.0, 0.0, 1.0, 0.0, 0.0],
        [10.0, 6.0, 1.0, 0.0, 0.0, 0.0],
    ]),
    # ...
}

I think this would be quite reasonable to implement already. We'll likely need to do something special for very high-cardinality categorical features and other more exotic data types, but we can cross that bridge when we come to it.
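
A minimal sketch of how the flat encoding could be produced, assuming `Scalar` and `Categorical` specs like the ones in the example above. The `encode_entity` helper is hypothetical, and plain lists are used here to keep the sketch dependency-free; passing `rows` to `np.array` would give the 2D array shown in the example.

```python
from dataclasses import dataclass
from typing import List, Sequence, Union


@dataclass(frozen=True)
class Scalar:  # hypothetical spec for a scalar feature
    name: str


@dataclass(frozen=True)
class Categorical:  # hypothetical spec for a categorical feature
    name: str
    categories: int


Feature = Union[Scalar, Categorical]


def encode_entity(features: Sequence[Feature], values: Sequence[float]) -> List[float]:
    """Flatten one entity's raw values into a row, one-hot encoding categoricals."""
    if len(features) != len(values):
        raise ValueError("expected exactly one value per feature")
    row: List[float] = []
    for feature, value in zip(features, values):
        if isinstance(feature, Categorical):
            index = int(value)
            if not 0 <= index < feature.categories:
                raise ValueError(f"{feature.name}: category {index} out of range")
            one_hot = [0.0] * feature.categories
            one_hot[index] = 1.0
            row.extend(one_hot)
        else:
            row.append(float(value))
    return row


food = [Scalar("x"), Scalar("y"), Categorical("Color", 4)]
# Two food entities at (10, 6) with color indices 1 and 0.
rows = [encode_entity(food, [10, 6, 1]), encode_entity(food, [10, 6, 0])]
print(rows)
```

Each row has one column per scalar plus one column per category, so the row width for an entity type is fixed by its feature specs.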