google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

How do I load a pre-trained model? #438

Closed: btaba closed this issue 1 month ago

btaba commented 6 months ago

Discussed in

Originally posted by **eleninisioti** October 11, 2023

There is a notebook that explains how to save and load models, but there testing happens right after training: calling `make_inference_fn(params)` requires first running `make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)`, i.e. the whole training process. My question is how I can test without running the training process, simply by loading the params. How can I get `inference_fn` without training?
Rian-Jo commented 4 months ago

@btaba Thank you for your work.

Could you let us know when there will be an update for this issue?

In addition, is there any plan to support deploying the inference function or params/networks to C++, similar to TorchScript?

Thank you.
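For the C++ part of this question, one possible route (a sketch, not something brax provides) is JAX's ahead-of-time lowering API, which recent JAX releases recommend in place of the deprecated `jax.xla_computation`. The `policy` function below is a hypothetical stand-in for a jitted brax inference function; `jax.jit(...).lower(...)` traces it for one concrete input shape and dtype, and `as_text()` returns the resulting StableHLO module as text.

```python
import jax
import jax.numpy as jnp


def policy(obs):
    # Hypothetical stand-in for jit_inference_fn: a fixed linear layer + tanh.
    # A real brax policy would also take a PRNG key and close over trained params.
    w = jnp.ones((obs.shape[-1], 2), dtype=obs.dtype)
    return jnp.tanh(obs @ w)


# Lowering is specialised to this example input's shape and dtype.
example_obs = jnp.zeros((4,), dtype=jnp.float32)
lowered = jax.jit(policy).lower(example_obs)

# StableHLO (MLIR) text of the traced computation; the entry function is @main.
stablehlo_text = lowered.as_text()
print(stablehlo_text)
```

Loading and executing such a module from C++ (for example through an XLA runtime) is outside the scope of this sketch.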

willthibault commented 1 month ago


Thank you for creating the issue @btaba from the original discussion.

Will this sort of functionality or @jihan1218's solution become part of the main branch at some point?


btaba commented 1 month ago

Should be fixed in b164655

Here's an example:

import functools

from brax.training.agents.ppo import train as ppo
from etils import epath
from flax.training import orbax_utils
from orbax import checkpoint as ocp

ckpt_path = epath.Path('/tmp/some-env/ckpts')
ckpt_path.mkdir(parents=True, exist_ok=True)

def policy_params_fn(current_step, make_policy, params):
  # save checkpoints
  orbax_checkpointer = ocp.PyTreeCheckpointer()
  save_args = orbax_utils.save_args_from_target(params)
  path = ckpt_path / f'{current_step}'
  orbax_checkpointer.save(path, params, force=True, save_args=save_args)

train_fn = functools.partial(
    ppo.train, num_timesteps=100_000_000,
    policy_params_fn=policy_params_fn,  # called by ppo.train during training
    restore_checkpoint_path=ckpt_path / '11141120',  # to restart from a previous checkpoint
)

# env and progress are assumed to be defined earlier
make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)
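The example hardcodes the step directory `'11141120'` in `restore_checkpoint_path`. Since `policy_params_fn` writes one subdirectory per `current_step`, a small hypothetical helper such as `latest_checkpoint` below could pick the newest one instead. It only considers local directories whose names are purely digits, which (as an assumption about orbax internals) would also skip any temporary directory left behind by an in-progress save.

```python
from pathlib import Path
from typing import Optional, Union


def latest_checkpoint(ckpt_dir: Union[str, Path]) -> Optional[Path]:
    """Return the <ckpt_dir>/<step> subdirectory with the largest step, or None.

    Hypothetical helper: assumes checkpoints are saved as one local directory
    per training step, named by the integer step (e.g. ckpt_path / '11141120').
    """
    step_dirs = [
        p for p in Path(ckpt_dir).iterdir()
        if p.is_dir() and p.name.isdigit()  # ignore temp dirs and stray files
    ]
    # Compare numerically so '11141120' beats '20' (a string sort would not).
    return max(step_dirs, key=lambda p: int(p.name), default=None)
```

Usage would then be along the lines of `restore_checkpoint_path=latest_checkpoint(str(ckpt_path))`.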

You can recover the inference fn without training like so:

make_inference_fn, params, _= ppo.train(environment=env, num_timesteps=0)

And the params can be loaded from the checkpoint using orbax.

Rian-Jo commented 3 weeks ago

@btaba Hi.

I ran my code with orbax_checkpointer like the example above and got '_CHECKPOINT_METADATA', '_METADATA', and 'checkpoint' files. Then, when I tried to restore the checkpoint, I got the following error:

    136 print("*********************************")
    138 evalEnv = envs.GetEnvironment(self.envName)
--> 140 make_inference_fn, self.params, self.metrics = train_fn(environment=self.env, progress_fn=ProgressCallback, eval_env=evalEnv, policy_params_fn=PolicyCallback)
    142 # create data frame with train_rewards and steps
    143 df = pd.DataFrame({'Steps': trainSteps, 'Rewards': trainRewards, 'RewardsError': trainRewardsErr})

File ~/.pyenv/versions/3.10.14/lib/python3.10/site-packages/brax/training/agents/ppo/, in train(environment, num_timesteps, episode_length, action_repeat, num_envs, max_devices_per_host, num_eval_envs, learning_rate, entropy_cost, discounting, seed, unroll_length, batch_size, num_minibatches, num_updates_per_batch, num_evals, num_resets_per_eval, normalize_observations, reward_scaling, clipping_epsilon, gae_lambda, deterministic_eval, network_factory, progress_fn, normalize_advantage, eval_env, policy_params_fn, randomization_fn, restore_checkpoint_path)
    397   orbax_checkpointer = ocp.PyTreeCheckpointer()
    398   target = training_state.normalizer_params, init_params
--> 399   (normalizer_params, init_params) = orbax_checkpointer.restore(
    400       restore_checkpoint_path, item=target
    401   )
    402   training_state = training_state.replace(
    403       normalizer_params=normalizer_params, params=init_params
    404   )
    406 training_state = jax.device_put_replicated(
    407     training_state,
    408     jax.local_devices()[:local_devices_to_use])

File ~/.pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/, in Checkpointer.restore(self, directory, *args, **kwargs)
    209'Restoring item from %s.', directory)
    210 ckpt_args = construct_checkpoint_args(self._handler, False, *args, **kwargs)
--> 211 restored = self._handler.restore(directory, args=ckpt_args)
    212'Finished restoring checkpoint from %s.', directory)
    213 multihost.sync_global_processes(
    214     multihost.unique_barrier_key(
    215         'Checkpointer:restore',
    219     processes=self._active_processes,
    220 )

File ~/.pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/, in PyTreeCheckpointHandler.restore(self, directory, item, restore_args, transforms, transforms_default_to_original, legacy_transform_fn, args)
    628 if (
    629     (directory / _METADATA_FILE).exists()
    630     and transforms is None
    631     and legacy_transform_fn is None
    632 ):
    633   args = BasePyTreeRestoreArgs(
    634       item,
    635       restore_args=restore_args,
    636   )
--> 637   return self._handler_impl.restore(directory, args=args)
    639 logging.debug('directory=%s, restore_args=%s', directory, restore_args)
    640 if not directory.exists():

File ~/.pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/, in BasePyTreeCheckpointHandler.restore(self, directory, args)
    794   logging.debug(
    795       'ts_metrics: %s',
    796       json.dumps(ts.experimental_collect_matching_metrics('/tensorstore/')),
    797   )
    799 if item is not None:
--> 800   return utils.deserialize_tree(restored_item, item)
    801 return restored_item

File ~/.pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/, in deserialize_tree(serialized, target, keep_empty_nodes)
    266     result = result[key_name]
    267   return result
--> 269 return jax.tree_util.tree_map_with_path(
    270     _reconstruct_from_keypath,
    271     target,
    272     is_leaf=is_empty_or_leaf if keep_empty_nodes else None,
    273 )

File ~/.pyenv/versions/3.10.14/lib/python3.10/site-packages/jax/_src/, in tree_map_with_path(f, tree, is_leaf, *rest)
    999 keypath_leaves = list(zip(*keypath_leaves))
   1000 all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
-> 1001 return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves))

File ~/.pyenv/versions/3.10.14/lib/python3.10/site-packages/jax/_src/, in <genexpr>(.0)
    999 keypath_leaves = list(zip(*keypath_leaves))
   1000 all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
-> 1001 return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves))

File ~/.pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/, in deserialize_tree.<locals>._reconstruct_from_keypath(keypath, _)
    264   if not isinstance(result, list) and key_name not in result:
    265     key_name = str(key_name)
--> 266   result = result[key_name]
    267 return result

KeyError: 'policy'

Can you give me any hints on how to solve this problem?

Thank you!
btaba commented 3 weeks ago

Hi @Rian-Jo , are you saving the full params in your checkpoint? It looks like the "policy" key is missing. Take a look at the colab example here to see what the diff is:
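The `KeyError: 'policy'` is raised while orbax matches the saved tree against the restore target key by key, as the last frames of the traceback show. One quick way to see which keys the target expects but the checkpoint lacks is to compare the leaf paths of both trees. The helper below is a hypothetical, pure-Python sketch: `NetworkParams` is an illustrative stand-in for a params container with `policy` and `value` fields (not brax's own class), the `saved` tree is made up for the example, and positions are matched as strings, similar to the `str(key_name)` fallback visible in the traceback.

```python
from collections import namedtuple

# Hypothetical stand-in for a params container with `policy`/`value` fields.
NetworkParams = namedtuple("NetworkParams", ["policy", "value"])


def leaf_paths(tree, prefix=()):
    """Yield the key path of every leaf as a tuple of strings.

    Dict keys and namedtuple fields are used by name; tuple/list positions are
    turned into strings ('0', '1', ...) so they also match string-keyed dicts.
    """
    if isinstance(tree, dict):
        children = [(str(k), v) for k, v in tree.items()]
    elif isinstance(tree, tuple) and hasattr(tree, "_fields"):  # namedtuple
        children = list(zip(tree._fields, tree))
    elif isinstance(tree, (tuple, list)):
        children = [(str(i), v) for i, v in enumerate(tree)]
    else:
        yield prefix  # anything else (array, scalar, None) is a leaf
        return
    for key, child in children:
        yield from leaf_paths(child, prefix + (key,))


def missing_paths(saved, target):
    """Leaf paths that `target` expects but `saved` does not contain."""
    saved_paths = set(leaf_paths(saved))
    return sorted(p for p in leaf_paths(target) if p not in saved_paths)


# Illustrative trees: the target expects (normalizer, NetworkParams(policy, value)),
# but the checkpoint only holds (normalizer, policy-network params).
target = ({"mean": 0.0}, NetworkParams(policy={"w": 1.0}, value={"w": 2.0}))
saved = {"0": {"mean": 0.0}, "1": {"params": {"w": 1.0}}}
print(missing_paths(saved, target))  # [('1', 'policy', 'w'), ('1', 'value', 'w')]
```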

Rian-Jo commented 3 weeks ago

@btaba I got it. Thank you so much!!

One more question.

Is there any difference between the params returned by the train function and the same params after saving and loading? Simulation playback with the former and the latter seems different. Also, is there any difference between the make_inference_fn from the train function and the one from jax.xla_computation?

Thank you.

btaba commented 3 weeks ago

Can you be a bit more specific about what the "saved/loaded params" are? What is the make_inference_fn from jax.xla_computation? I'm not understanding what the question is.

Rian-Jo commented 3 weeks ago

Hi @btaba,

I have three jit_inference_fn variants:

  1. The trained make_inference_fn:

    make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)
    jit_inference_fn = jax.jit(make_inference_fn(params))
    model.save_params(model_path, params)

  2. The restored make_inference_fn with the saved params:

    make_inference_fn, params, _ = train_fn(environment=env, num_timesteps=0)
    params = model.load_params(model_path)
    jit_inference_fn = jax.jit(make_inference_fn(params))

  3. A wrapped version of jit_inference_fn using jax.xla_computation:

    c = jax.xla_computation(jit_inference_fn)(...)

I use c from case 3 in C++.

Should these three behave the same? The results I get from each are different. I wonder whether something changes during the process:

(case 2) when the params are saved and loaded, or when train_fn is called with 'num_timesteps=0';
(case 3) when the function is wrapped with xla_computation.
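One way to narrow down case 2 is to check whether the params themselves survive the save/load round trip. The sketch below (hypothetical helper name `max_abs_diff`, assuming both trees hold array-like leaves) compares two params pytrees leaf by leaf with `jax.tree_util`. If the difference is zero, the weights are identical and the differing playback more likely comes from elsewhere; for example, brax's PPO inference function samples actions stochastically unless it is built with `make_inference_fn(params, deterministic=True)`, so rollouts driven by different PRNG keys will generally differ.

```python
import jax
import numpy as np


def max_abs_diff(a, b):
    """Largest absolute elementwise difference between two params pytrees.

    Raises ValueError if the trees have different structure, e.g. when one
    was restored without a target and came back with lists or dicts instead
    of tuples.
    """
    if jax.tree_util.tree_structure(a) != jax.tree_util.tree_structure(b):
        raise ValueError("params trees have different structure")
    leaf_diffs = [
        float(np.max(np.abs(np.asarray(x) - np.asarray(y)), initial=0.0))
        for x, y in zip(jax.tree_util.tree_leaves(a), jax.tree_util.tree_leaves(b))
    ]
    return max(leaf_diffs, default=0.0)
```

With the code from case 1 and case 2, this would be called as `max_abs_diff(params, model.load_params(model_path))` right after training.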



Meanwhile, in the quadruped example, what is the purpose of editing the state value qvel? I understand qvel is a sensed value. How does adding a random kick to the floating base state (qvel[:2]) make the legs kick?

I imagined that if random noise is added to the floating base state and pipeline_step runs, the floating base, not the legs, would keep jumping in the random direction.

...
def step(self, state: State, action: jax.Array) -> State:  # pytype: disable=signature-mismatch
  rng, cmd_rng, kick_noise_2 = jax.random.split(state.info['rng'], 3)

  # kick
  push_interval = 10
  kick_theta = jax.random.uniform(kick_noise_2, maxval=2 * jp.pi)
  kick = jp.array([jp.cos(kick_theta), jp.sin(kick_theta)])
  kick *= jp.mod(state.info['step'], push_interval) == 0
  qvel = state.pipeline_state.qvel  # pytype: disable=attribute-error
  qvel = qvel.at[:2].set(kick * self._kick_vel + qvel[:2])
  state = state.tree_replace({'pipeline_state.qvel': qvel})

  # physics step
  motor_targets = self._default_pose + action * self._action_scale
  motor_targets = jp.clip(motor_targets, self.lowers, self.uppers)
  pipeline_state = self.pipeline_step(state.pipeline_state, motor_targets)
  x, xd = pipeline_state.x, pipeline_state.xd
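In the snippet, the kick is a one-off velocity impulse rather than a continuous push: `kick` is multiplied by `jp.mod(state.info['step'], push_interval) == 0`, so it is zero on every step except multiples of `push_interval`, and on those steps a random horizontal direction scaled by `self._kick_vel` is added once to `qvel[:2]` before `pipeline_step` runs. The standalone sketch below isolates that computation; `apply_kick`, the `kick_vel` value and the 6-element `qvel` are illustrative assumptions, and `qvel[:2]` is assumed to be the floating base's x/y linear velocity (a free joint listed first).

```python
import jax
import jax.numpy as jp


def apply_kick(qvel, rng, step, kick_vel, push_interval=10):
    """Add a random horizontal velocity impulse to qvel[:2] every push_interval steps.

    Mirrors the kick logic of the quadruped step() above; qvel[:2] is assumed
    to be the floating base's x/y linear velocity.
    """
    kick_theta = jax.random.uniform(rng, maxval=2 * jp.pi)  # random direction
    kick = jp.array([jp.cos(kick_theta), jp.sin(kick_theta)])  # unit vector
    kick *= jp.mod(step, push_interval) == 0  # zero except on kick steps
    return qvel.at[:2].set(kick * kick_vel + qvel[:2])


rng = jax.random.PRNGKey(0)
qvel = jp.zeros(6)
print(apply_kick(qvel, rng, step=3, kick_vel=0.5))   # unchanged: not a kick step
print(apply_kick(qvel, rng, step=10, kick_vel=0.5))  # norm of qvel[:2] is about 0.5
```

Between kicks the physics step integrates normally, so the base is perturbed once and the policy then has to react through the legs.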


Thank you.