google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Trained model inference #519

Closed OmurAydogmus closed 2 months ago

OmurAydogmus commented 2 months ago

I need to use the trained network in separate test code without using the training-side code. I have a problem with the signature of inference_fn, as shown below:

# training code
inference_fn
Out[10]: <function brax.training.agents.ppo.networks.make_inference_fn.<locals>.make_policy.<locals>.policy(observations: jax.Array, key_sample: jax.Array) -> Tuple[jax.Array, Mapping[str, Any]]>
# test code
inference_fn
Out[181]: <function brax.training.agents.ppo.networks.make_inference_fn.<locals>.make_policy(params: Tuple[Any, Any], deterministic: bool = False) -> brax.training.types.Policy>
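
For reference, those two signatures are two levels of the same wrapping: make_inference_fn(networks) returns the make_policy factory, and calling that factory with the loaded params yields the policy with the (observations, key) signature shown in the training code. A rough sketch of the call chain, assuming env already exists and the network sizes match training:

import jax
from brax.io import model
from brax.training.agents.ppo import networks as ppo_networks

params = model.load_params('/tmp/mjx_brax_quadruped_policy')
networks = ppo_networks.make_ppo_networks(
    observation_size=env.observation_size,
    action_size=env.action_size,
    policy_hidden_layer_sizes=(128, 128, 128, 128))

make_policy = ppo_networks.make_inference_fn(networks)  # params -> policy
policy = make_policy(params, deterministic=True)        # (obs, key) -> (action, extras)
action, _ = jax.jit(policy)(obs, jax.random.PRNGKey(0)) # obs: a single observation from env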

Test code:

import functools
import jax

from brax.io import model
from brax.training.agents.ppo import networks as ppo_networks

# Load the saved parameters
model_path = '/tmp/mjx_brax_quadruped_policy'
params = model.load_params(model_path)

# Create the environment (ensure this matches the original environment from Code1)
observation_size = env.observation_size
action_size = env.action_size

# Re-create the policy and value networks with the exact same architecture
make_networks_factory = functools.partial(
    ppo_networks.make_ppo_networks,
    observation_size=observation_size,
    action_size=action_size,
    policy_hidden_layer_sizes=(128, 128, 128, 128)  # Match the architecture from Code1
)

# Initialize the PPO networks (policy and value networks)
ppo_modified = make_networks_factory()

# Now use `make_inference_fn` to create the inference function
inference_fn = ppo_networks.make_inference_fn(ppo_modified)

# JIT-compile the inference function
jit_inference_fn = jax.jit(inference_fn(params)) 
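
A quick way to sanity-check the wrapped policy is to call it on a single dummy observation before running a full rollout; a small sketch, assuming env is the same environment object used above:

import jax.numpy as jnp

dummy_obs = jnp.zeros(env.observation_size)
act, extras = jit_inference_fn(dummy_obs, jax.random.PRNGKey(0))
print(act.shape)  # should match (env.action_size,)
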
btaba commented 2 months ago

Hi @OmurAydogmus, what's the issue exactly? Does the "test" code not work?

OmurAydogmus commented 2 months ago

Thank you, btaba. I think I found the mistake. How can we normalize the observations? normalize_observations=True was selected during training.

import jax
import mediapy as media

from brax import envs
from brax.io import model
from brax.training.agents.ppo import networks as ppo_networks

# Create the environment first (same env_name and backend as in training)
env = envs.create(env_name=env_name, backend=backend)

# Load Policies and Test
paramsTEST = model.load_params('/tmp/params')

ppoTEST = ppo_networks.make_ppo_networks(
    action_size=env.action_size, observation_size=env.observation_size)
make_inference = ppo_networks.make_inference_fn(ppoTEST)
inference_fnTEST = make_inference(paramsTEST)

jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)
jit_inference_fn = jax.jit(inference_fnTEST)

rollout = []
rng = jax.random.PRNGKey(seed=1)
state = jit_env_reset(rng=rng)
for _ in range(100):
  rollout.append(state.pipeline_state)
  act_rng, rng = jax.random.split(rng)
  act, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_env_step(state, act)

media.show_video(env.render(rollout, camera='track'), fps=1.0 / env.dt)

OmurAydogmus commented 2 months ago

I think it is okay now. We need to define preprocess_observations_fn using normalization:

from brax.training.acme import running_statistics

ppoTEST = ppo_networks.make_ppo_networks(
    action_size=env.action_size,
    observation_size=env.observation_size,
    preprocess_observations_fn=running_statistics.normalize)
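
For context, this is the same wiring ppo.train uses internally when normalize_observations=True, and the saved params include the running-statistics state that this preprocessor consumes. Putting the pieces together, the corrected test-side setup looks roughly like this (a sketch; env_name, backend, and any non-default layer sizes are assumed to match the training run):

import jax

from brax import envs
from brax.io import model
from brax.training.acme import running_statistics
from brax.training.agents.ppo import networks as ppo_networks

env = envs.create(env_name=env_name, backend=backend)

params = model.load_params('/tmp/params')

# Pass policy_hidden_layer_sizes=... as well if training used non-default sizes.
ppo_network = ppo_networks.make_ppo_networks(
    observation_size=env.observation_size,
    action_size=env.action_size,
    preprocess_observations_fn=running_statistics.normalize)
make_policy = ppo_networks.make_inference_fn(ppo_network)
jit_inference_fn = jax.jit(make_policy(params))

# One policy step to check that everything is wired correctly.
rng = jax.random.PRNGKey(0)
state = jax.jit(env.reset)(rng)
act, _ = jit_inference_fn(state.obs, rng)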