epignatelli / navix

Accelerated minigrid environments with JAX
Apache License 2.0
103 stars · 7 forks · source link

Rendering two sprites in the same cell #34

Status: Open · epignatelli opened this issue 1 year ago

epignatelli commented 1 year ago
import jax.numpy as jnp
import navix as nx
import matplotlib.pyplot as plt

# Minimal reproduction for "Rendering two sprites in the same cell": place a
# Goal and a Player on the same grid cell, (0, 0), scatter their sprites into
# a canvas, and plot the result.

# Canvas of shape (1, 1, 32, 32, 3), dtype uint8. The two leading axes are
# indexed by entity position below. Each cell holds one 32x32x3 tile.
grid = jnp.zeros((1, 1, 32, 32, 3), dtype=jnp.uint8)
# Both entities are created at the same position, (0, 0).
goal = nx.entities.Goal.create(position=jnp.asarray((0, 0)), probability=jnp.asarray(1.0))
player = nx.entities.Player.create(position=jnp.asarray((0, 0)))

# NOTE(review): goal's position and sprite are indexed with [0] but player's
# are not. Presumably Goal.create adds a leading batch axis that
# Player.create does not; confirm against navix.entities before relying on it.
positions = jnp.stack([goal.position[0], player.position])
sprites = jnp.stack([goal.get_sprite(nx.graphics.SPRITES_REGISTRY)[0], player.get_sprite(nx.graphics.SPRITES_REGISTRY)])
# positions.T split into a tuple gives one index array per leading canvas
# axis, so sprite i is written into the cell at positions[i].
# Both rows of `positions` are (0, 0), so this is a scatter with duplicate
# indices. Per the JAX `.at[]` docs, the order in which conflicting updates
# are applied is implementation-defined. With `.set`, the cell therefore ends
# up holding one sprite's tile rather than a composite of both. This is the
# behavior this issue reports.
image = grid.at[tuple(positions.T)].set(sprites)
# Interleave grid axes with tile axes:
# (g0, g1, 32, 32, 3) -> (g0, 32, g1, 32, 3), so a reshape can lay the
# tiles out as one image.
image = jnp.swapaxes(image, 1, 2)
# Hard-coded for the 1x1 grid above. A g0 x g1 grid would need
# reshape(g0 * 32, g1 * 32, 3).
image = image.reshape(32, 32, 3)
# NOTE(review): outside an interactive/notebook backend, plt.show() is
# presumably needed to actually open the figure.
plt.imshow(image)

# Bare expression statement: a notebook cell displays its value. In a plain
# Python script it has no effect.
image