instadeepai / jumanji

🕹️ A diverse suite of scalable reinforcement learning environments in JAX
https://instadeepai.github.io/jumanji
Apache License 2.0

Add data augmentation for Sokoban #222

Status: open. carlosgmartin opened this issue 8 months ago.

carlosgmartin commented 8 months ago

Add an option to apply data augmentation when sampling 10 × 10 Sokoban levels from the DeepMind Boxoban dataset.

More precisely: at reset, apply one of the 8 symmetries of the square (the four rotations and four reflections that form the dihedral group D4), chosen uniformly at random, to the loaded level. Here's an example implementation:

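import jax.numpy as jnp
from jax import random

# Three fair coin flips pick one of the 8 symmetries, each with probability 1/8.
bits = random.bernoulli(key, shape=(3,))
level = jnp.where(bits[0], level, level[::-1, :])  # vertical flip if bits[0] is False
level = jnp.where(bits[1], level, level[:, ::-1])  # horizontal flip if bits[1] is False
level = jnp.where(bits[2], level, jnp.rot90(level))  # 90-degree rotation if bits[2] is False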
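A self-contained sketch of the same three-bit scheme follows. It assumes the level is a square 2-D array and that the agent's (row, col) position is stored separately from the grid, so the position must be transformed along with it; `apply_symmetry` and `augment_level` are hypothetical names, not part of Jumanji's API. It keeps the convention above, where a False bit applies the corresponding transform.

```python
import itertools

import jax
import jax.numpy as jnp
from jax import random


def apply_symmetry(grid, agent_rc, bits):
    """Apply one of the 8 square symmetries selected by three booleans.

    grid: (n, n) array; agent_rc: int array [row, col]; bits: bool array, shape (3,).
    Hypothetical helper: a False bit applies the transform (same as the snippet above).
    """
    n = grid.shape[0]
    r, c = agent_rc[0], agent_rc[1]

    # Vertical flip: row r moves to n - 1 - r.
    grid = jnp.where(bits[0], grid, grid[::-1, :])
    r = jnp.where(bits[0], r, n - 1 - r)

    # Horizontal flip: column c moves to n - 1 - c.
    grid = jnp.where(bits[1], grid, grid[:, ::-1])
    c = jnp.where(bits[1], c, n - 1 - c)

    # jnp.rot90 rotates counterclockwise: (r, c) moves to (n - 1 - c, r).
    # Requires a square grid so both jnp.where branches have the same shape.
    grid = jnp.where(bits[2], grid, jnp.rot90(grid))
    r, c = jnp.where(bits[2], r, n - 1 - c), jnp.where(bits[2], c, r)

    return grid, jnp.stack([r, c])


@jax.jit
def augment_level(key, grid, agent_rc):
    """Hypothetical reset-time hook: sample one of the 8 symmetries uniformly."""
    bits = random.bernoulli(key, shape=(3,))
    return apply_symmetry(grid, agent_rc, bits)


# Usage: distinct cell values make every symmetry distinguishable.
key = random.PRNGKey(0)
grid = jnp.arange(100).reshape(10, 10)
new_grid, new_agent = augment_level(key, grid, jnp.array([2, 7]))
```

Because `jnp.where` evaluates both branches, every candidate transform is computed on each call. For a 10 × 10 grid this cost is negligible, and it keeps the function jit-compatible without Python control flow on traced values.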