Status: Open — opened by epignatelli 1 year ago.
# Render a Goal and a Player sprite onto a tiled navix grid and display it.
#
# Both entities are placed on cell (0, 0), so their sprites overlap. The
# sprites are painted in an explicit order so the picture does not depend on
# how JAX resolves duplicate scatter indices.

import jax.numpy as jnp
import matplotlib.pyplot as plt
import navix as nx

# Tile grid laid out as (rows, cols, tile_height, tile_width, channels).
grid = jnp.zeros((1, 1, 32, 32, 3), dtype=jnp.uint8)

goal = nx.entities.Goal.create(
    position=jnp.asarray((0, 0)),
    probability=jnp.asarray(1.0),
)
player = nx.entities.Player.create(position=jnp.asarray((0, 0)))

# NOTE(review): Goal.create appears to return a batched entity (hence the
# [0] on its position and sprite) while Player does not -- confirm in navix.
positions = jnp.stack([goal.position[0], player.position])
sprites = jnp.stack([
    goal.get_sprite(nx.graphics.SPRITES_REGISTRY)[0],
    player.get_sprite(nx.graphics.SPRITES_REGISTRY),
])

# A single `.at[...].set()` whose indices contain duplicates leaves the
# winning update implementation-defined. Painting one entity at a time makes
# later entries (the player) deterministically overwrite earlier ones (the
# goal), i.e. list order is the z-order.
image = grid
for position, sprite in zip(positions, sprites):
    image = image.at[tuple(position)].set(sprite)

# Stitch tiles into one picture:
# (R, C, h, w, ch) -> swapaxes -> (R, h, C, w, ch) -> reshape -> (R*h, C*w, ch).
# Sizes come from the array itself so any grid/tile size works.
n_rows, n_cols, tile_h, tile_w, n_channels = image.shape
image = jnp.swapaxes(image, 1, 2).reshape(
    n_rows * tile_h, n_cols * tile_w, n_channels
)

plt.imshow(image)
plt.show()  # needed to open a window when run as a script (no-op-ish in notebooks)