epignatelli / navix

Accelerated minigrid environments with JAX
Apache License 2.0

The current solution for `State` is not scalable to other, new entities. #33

Closed: epignatelli closed this issue 1 year ago

epignatelli commented 1 year ago

The current solution for `State` does not scale to new entity types.

Consider replacing the separate `players`, `doors`, `keys`, and `goals` fields of the state with a single collection of entities whose length is fixed at compile time.

From this:

class State(struct.PyTreeNode):
    """The Markovian state of the environment"""

    key: KeyArray
    """The random number generator state"""
    grid: Array
    """The base map of the environment that remains constant throughout the training"""
    cache: RenderingCache
    """The rendering cache to speed up rendering"""
    players: Player = Player.create()
    """The player entity"""
    goals: Goal = Goal.create()
    """The goal entity, batched over the number of goals"""
    keys: Key = Key.create()
    """The key entity, batched over the number of keys"""
    doors: Door = Door.create()
    """The door entity, batched over the number of doors"""

To this:

class State(struct.PyTreeNode):
    """The Markovian state of the environment"""

    key: KeyArray
    """The random number generator state"""
    grid: Array
    """The base map of the environment that remains constant throughout the training"""
    cache: RenderingCache
    """The rendering cache to speed up rendering"""
    entities: Tuple[Entity, ...]
    """All entities in the environment (players, goals, keys, doors, ...)"""
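A minimal sketch of the proposed layout, assuming each entity type is a batched pytree whose arrays carry a leading axis of size n. To depend only on JAX, it uses `typing.NamedTuple` in place of flax's `struct.PyTreeNode` and omits the `cache: RenderingCache` field. The entity fields (`position`, `direction`, `id`, `open`) and the helpers `make_state` and `add_entity` are hypothetical illustrations, not navix's API. The point it shows is that a new entity type extends the tuple without touching the `State` definition.

```python
from typing import Any, NamedTuple, Tuple

import jax
import jax.numpy as jnp


# Hypothetical entity types. NamedTuples are JAX pytrees out of the box;
# each entity is "batched": its arrays carry a leading axis of size n.
class Player(NamedTuple):
    position: jax.Array   # (n_players, 2) grid coordinates
    direction: jax.Array  # (n_players,) heading index


class Goal(NamedTuple):
    position: jax.Array   # (n_goals, 2)


class Key(NamedTuple):
    position: jax.Array   # (n_keys, 2)
    id: jax.Array         # (n_keys,) which door each key opens


class State(NamedTuple):
    key: jax.Array             # random number generator state
    grid: jax.Array            # base map of the environment
    entities: Tuple[Any, ...]  # fixed length, known at trace/compile time


def make_state() -> State:
    entities = (
        Player(position=jnp.array([[1, 1]]), direction=jnp.array([0])),
        Goal(position=jnp.array([[3, 3], [4, 4]])),
        Key(position=jnp.array([[2, 2]]), id=jnp.array([0])),
    )
    return State(
        key=jax.random.PRNGKey(0),
        grid=jnp.zeros((5, 5), dtype=jnp.int32),
        entities=entities,
    )


# A new entity type needs no change to State: it is appended to the tuple.
class Door(NamedTuple):
    position: jax.Array   # (n_doors, 2)
    open: jax.Array       # (n_doors,) bool


def add_entity(state: State, entity: Any) -> State:
    # Returns a new State; the tuple length (and so the pytree structure) grows by one.
    return state._replace(entities=state.entities + (entity,))


state = add_entity(
    make_state(),
    Door(position=jnp.array([[0, 2]]), open=jnp.array([False])),
)
```

Because the tuple's length and element types are part of the pytree structure, each distinct set of entity types produces a separate trace under `jax.jit`, which matches the "fixed-length at compile time" requirement.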

The main obstacle is the computational cost of iterating through the collection whenever we need to extract a specific entity, for example the player when applying an action.
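One way to frame that cost, sketched under assumptions and not necessarily how #39 resolved it: if `entities` is a fixed-length Python tuple holding one batched entry per entity type, a lookup by type is ordinary Python control flow over the pytree structure. Under `jax.jit` it runs once, during tracing, so the compiled step contains only the array ops on the selected entity and does not iterate at runtime. The cost shifts to trace time and grows with the number of entities. `get_entity`, `replace_entity`, `move_player`, and the `position` field are hypothetical names for illustration. The approach does not apply if the choice of entity depends on a traced value.

```python
from typing import Any, NamedTuple, Tuple, Type

import jax
import jax.numpy as jnp


class Player(NamedTuple):
    position: jax.Array  # (n_players, 2)


class Goal(NamedTuple):
    position: jax.Array  # (n_goals, 2)


def get_entity(entities: Tuple[Any, ...], entity_type: Type) -> Any:
    """Return the first entity of the given type.

    Inside jax.jit, the tuple length and element types are static, so this
    loop and the isinstance checks execute at trace time only.
    """
    for entity in entities:
        if isinstance(entity, entity_type):
            return entity
    raise KeyError(f"no entity of type {entity_type.__name__}")


def replace_entity(entities: Tuple[Any, ...], new_entity: Any) -> Tuple[Any, ...]:
    """Swap in new_entity for the entry of the same type.

    Assumes one batched entry per entity type.
    """
    return tuple(new_entity if type(e) is type(new_entity) else e for e in entities)


@jax.jit
def move_player(entities: Tuple[Any, ...], delta: jax.Array) -> Tuple[Any, ...]:
    # Type-based selection is resolved while tracing; only the add is compiled.
    player = get_entity(entities, Player)
    moved = player._replace(position=player.position + delta)
    return replace_entity(entities, moved)


entities = (Player(position=jnp.array([[1, 1]])), Goal(position=jnp.array([[3, 3]])))
entities = move_player(entities, jnp.array([0, 1]))
```

Inspecting `jax.make_jaxpr(move_player)` on the same inputs should show a single addition rather than a loop, which is one way to check that the lookup was resolved during tracing.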

epignatelli commented 1 year ago

Solved by #39