Status: closed — epignatelli closed this issue 1 year ago.
import jax
import matplotlib.pyplot as plt
import navix as nx

# Build a KeyDoor environment whose observations come from nx.observations.rgb.
# NOTE(review): the positional args (12, 6, 100) are presumably grid height,
# width and max steps — confirm against the navix KeyDoor signature.
env = nx.environments.KeyDoor(12, 6, 100, observation_fn=nx.observations.rgb)

# Fixed PRNG seed so the rendered initial state is reproducible across runs.
key = jax.random.PRNGKey(0)
timestep = env.reset(key)

# Display the observation returned by reset(); with the rgb observation_fn this
# is presumably an (H, W, 3) image array — verify if the plot looks wrong.
plt.imshow(timestep.observation)
plt.show()