OmurAydogmus closed this issue 2 months ago.
Hi @OmurAydogmus, what's the issue exactly? Does the "test" code not work?
Thank you, btaba. I think I found the mistake. normalize_observations=True was selected during training, but the test code below does not normalize the observations. How can we normalize them?
```python
import jax
import mediapy as media
from brax import envs
from brax.io import model
from brax.training.agents.ppo import train as ppo

# Load Policies and Test
# The env must exist before its sizes are used below;
# env_name and backend must match the values used for training.
env = envs.create(env_name=env_name, backend=backend)

paramsTEST = model.load_params('/tmp/params')
# Mistake: the networks are built without observation normalization,
# although training used normalize_observations=True.
ppoTEST = ppo.ppo_networks.make_ppo_networks(action_size=env.action_size, observation_size=env.observation_size)
make_inference = ppo.ppo_networks.make_inference_fn(ppoTEST)
inference_fnTEST = make_inference(paramsTEST)

jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)
jit_inference_fn = jax.jit(inference_fnTEST)

rollout = []
rng = jax.random.PRNGKey(seed=1)
state = jit_env_reset(rng=rng)
for _ in range(100):
    rollout.append(state.pipeline_state)
    act_rng, rng = jax.random.split(rng)
    act, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_env_step(state, act)

media.show_video(env.render(rollout, camera='track'), fps=1.0 / env.dt)
```
I think it is okay now. We need to build the networks with preprocess_observations_fn set to running_statistics.normalize:

```python
from brax.training.acme import running_statistics
ppoTEST = ppo.ppo_networks.make_ppo_networks(action_size=env.action_size, observation_size=env.observation_size, preprocess_observations_fn=running_statistics.normalize)
```
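Why this matters, roughly: with normalize_observations=True, training keeps per-dimension running statistics of the observations, and the policy only ever sees x̂ = (x − μ) / σ. As far as I can tell, brax stores those statistics as the first element of the saved params tuple, and running_statistics.normalize applies them inside the network. A network built without it therefore feeds raw observations to weights that expect normalized ones. The sketch below is a standalone NumPy illustration of the idea (Welford-style running mean/std). The `RunningMeanStd` class and its `std_eps` floor are hypothetical and are not brax's implementation.

```python
import numpy as np


class RunningMeanStd:
    """Hypothetical helper: per-dimension running mean/std via Welford's algorithm.

    Illustrates observation normalization only; it is NOT brax's
    running_statistics module.
    """

    def __init__(self, obs_size: int, std_eps: float = 1e-6):
        self.count = 0
        self.mean = np.zeros(obs_size)
        self.m2 = np.zeros(obs_size)  # running sum of squared deviations
        self.std_eps = std_eps        # assumed floor on std to avoid division by zero

    def update(self, batch: np.ndarray) -> None:
        # Fold each observation (row) into the running statistics.
        for x in np.atleast_2d(batch):
            self.count += 1
            delta = x - self.mean
            self.mean += delta / self.count
            self.m2 += delta * (x - self.mean)

    @property
    def std(self) -> np.ndarray:
        # Population variance, clipped from below before taking the square root.
        var = self.m2 / max(self.count, 1)
        return np.sqrt(np.maximum(var, self.std_eps ** 2))

    def normalize(self, obs: np.ndarray) -> np.ndarray:
        # The same transform must be applied at training and at test time.
        return (obs - self.mean) / self.std


# "Training": statistics are updated from the observations seen so far.
stats = RunningMeanStd(obs_size=2)
stats.update(np.array([[1.0, 10.0], [3.0, 30.0]]))

# "Test": the frozen statistics are only applied, never updated.
print(stats.normalize(np.array([3.0, 30.0])))  # -> [1. 1.]
```

At test time the statistics are only applied, not recomputed, which is why they have to come from the saved params rather than from the test rollout itself.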
I need to use the trained network in a separate test script without relying on any training-side code. However, I have a problem with the shape of inference_fn, as shown below:
Test code: