Closed btaba closed 1 month ago
@btaba Thank you for your work.
Could you let us know when there will be an update on this issue?
In addition, is there any plan to support deploying the inference function or the params/networks to C++, similar to TorchScript?
thank you.
Hello,
Thank you for creating the issue @btaba from the original discussion.
Will this sort of functionality, or @jihan1218's solution, become part of the main branch at some point?
Thanks!
Should be fixed in b164655
Here's an example:
from orbax import checkpoint as ocp
from flax.training import orbax_utils
ckpt_path = epath.Path('/tmp/some-env/ckpts')
ckpt_path.mkdir(parents=True, exist_ok=True)
def policy_params_fn(current_step, make_policy, params):
# save checkpoints
orbax_checkpointer = ocp.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(params)
path = ckpt_path / f'{current_step}'
orbax_checkpointer.save(path, params, force=True, save_args=save_args)
train_fn = functools.partial(
ppo.train, num_timesteps=100_000_000,
policy_params_fn=policy_params_fn,
restore_checkpoint_path=ckpt_path / '11141120' # to restart from a previous checkpoint
)
make_inference_fn, params, _= train_fn(environment=env, progress_fn=progress)
You can recover the inference fn without training like so:
make_inference_fn, params, _= ppo.train(environment=env, num_timesteps=0)
And the params can be loaded from the checkpoint using orbax.
@btaba Hi.
I ran my code with orbax_checkpointer like the example above and got the '_CHECKPOINT_METADATA', '_METADATA', and 'checkpoint' files. Then, when I tried to restore the checkpoint, I got the following error:
136 print("*********************************")
138 evalEnv = envs.GetEnvironment(self.envName)
--> 140 make_inference_fn, self.params, self.metrics = train_fn(environment=self.env, progress_fn=ProgressCallback, eval_env=evalEnv, policy_params_fn=PolicyCallback)
142 # create data frame with train_rewards and steps
143 df = pd.DataFrame({'Steps': trainSteps, 'Rewards': trainRewards, 'RewardsError': trainRewardsErr})
File ~/.pyenv/versions/3.10.14/lib/python3.10/site-packages/brax/training/agents/ppo/train.py:399, in train(environment, num_timesteps, episode_length, action_repeat, num_envs, max_devices_per_host, num_eval_envs, learning_rate, entropy_cost, discounting, seed, unroll_length, batch_size, num_minibatches, num_updates_per_batch, num_evals, num_resets_per_eval, normalize_observations, reward_scaling, clipping_epsilon, gae_lambda, deterministic_eval, network_factory, progress_fn, normalize_advantage, eval_env, policy_params_fn, randomization_fn, restore_checkpoint_path)
397 orbax_checkpointer = ocp.PyTreeCheckpointer()
398 target = training_state.normalizer_params, init_params
--> 399 (normalizer_params, init_params) = orbax_checkpointer.restore(
400 restore_checkpoint_path, item=target
401 )
402 training_state = training_state.replace(
403 normalizer_params=normalizer_params, params=init_params
404 )
406 training_state = jax.device_put_replicated(
407 training_state,
408 jax.local_devices()[:local_devices_to_use])
File ~/.pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py:211, in Checkpointer.restore(self, directory, *args, **kwargs)
209 logging.info('Restoring item from %s.', directory)
210 ckpt_args = construct_checkpoint_args(self._handler, False, *args, **kwargs)
--> 211 restored = self._handler.restore(directory, args=ckpt_args)
212 logging.info('Finished restoring checkpoint from %s.', directory)
213 multihost.sync_global_processes(
214 multihost.unique_barrier_key(
215 'Checkpointer:restore',
(...)
219 processes=self._active_processes,
220 )
File ~/.pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py:637, in PyTreeCheckpointHandler.restore(self, directory, item, restore_args, transforms, transforms_default_to_original, legacy_transform_fn, args)
628 if (
629 (directory / _METADATA_FILE).exists()
630 and transforms is None
631 and legacy_transform_fn is None
632 ):
633 args = BasePyTreeRestoreArgs(
634 item,
635 restore_args=restore_args,
636 )
--> 637 return self._handler_impl.restore(directory, args=args)
639 logging.debug('directory=%s, restore_args=%s', directory, restore_args)
640 if not directory.exists():
File ~/.pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/base_pytree_checkpoint_handler.py:800, in BasePyTreeCheckpointHandler.restore(self, directory, args)
794 logging.debug(
795 'ts_metrics: %s',
796 json.dumps(ts.experimental_collect_matching_metrics('/tensorstore/')),
797 )
799 if item is not None:
--> 800 return utils.deserialize_tree(restored_item, item)
801 return restored_item
File ~/.pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/utils.py:269, in deserialize_tree(serialized, target, keep_empty_nodes)
266 result = result[key_name]
267 return result
--> 269 return jax.tree_util.tree_map_with_path(
270 _reconstruct_from_keypath,
271 target,
272 is_leaf=is_empty_or_leaf if keep_empty_nodes else None,
273 )
File ~/.pyenv/versions/3.10.14/lib/python3.10/site-packages/jax/_src/tree_util.py:1001, in tree_map_with_path(f, tree, is_leaf, *rest)
999 keypath_leaves = list(zip(*keypath_leaves))
1000 all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
-> 1001 return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves))
File ~/.pyenv/versions/3.10.14/lib/python3.10/site-packages/jax/_src/tree_util.py:1001, in <genexpr>(.0)
999 keypath_leaves = list(zip(*keypath_leaves))
1000 all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
-> 1001 return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves))
File ~/.pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/utils.py:266, in deserialize_tree.<locals>._reconstruct_from_keypath(keypath, _)
264 if not isinstance(result, list) and key_name not in result:
265 key_name = str(key_name)
--> 266 result = result[key_name]
267 return result
KeyError: 'policy'
'''
Can you give me any hints on how to solve the problem?
Thank you!
Hi @Rian-Jo , are you saving the full params in your checkpoint? It looks like the "policy" key is missing. Take a look at the colab example here to see what the diff is:
https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/mjx/tutorial.ipynb
@btaba I got it. Thank you so much!!
One more question.
Is there any difference between the params returned by the train function and the same params after being saved and loaded? Simulation playback with the first and the latter seems different. Also, is there any difference between the make_inference_fn from the train function and the one from jax.xla_computation?
Thank you.
Can you be a bit more specific about what the "saveed/loaded params" are? What is the make_inference_fn from jax.xla_computation ? Not understanding what the question is
Hi @btaba,
Here's an example:
from orbax import checkpoint as ocp from flax.training import orbax_utils ckpt_path = epath.Path('/tmp/some-env/ckpts') ckpt_path.mkdir(parents=True, exist_ok=True) def policy_params_fn(current_step, make_policy, params): # save checkpoints orbax_checkpointer = ocp.PyTreeCheckpointer() save_args = orbax_utils.save_args_from_target(params) path = ckpt_path / f'{current_step}' orbax_checkpointer.save(path, params, force=True, save_args=save_args) train_fn = functools.partial( ppo.train, num_timesteps=100_000_000, policy_params_fn=policy_params_fn, restore_checkpoint_path=ckpt_path / '11141120' # to restart from a previous checkpoint ) make_inference_fn, params, _= train_fn(environment=env, progress_fn=progress)
You can recover the inference fn without training like so:
make_inference_fn, params, _= ppo.train(environment=env, num_timesteps=0)
I have three jit_inference_fn.
1. trained make_inference_fn
make_inference_fn, params, _= train_fn(environment=env, progress_fn=progress)
jit_inference_fn = jax.jit(make_inference_fn(params))
model.save_params(model_path, params)
2. restored make_inference_fn with saved_params
make_inference_fn, params, _= train_fn(environment=env, num_timesteps=0)
params = model.load_params(model_path)
jit_inference_fn = jax.jit(make_inference_fn(params))
3. a wrapped version of jit_inference_fn by using [jax.xla_computation](https://jax.readthedocs.io/en/latest/_autosummary/jax.xla_computation.html)
c = jax.xla_computation(jit_inference_fn)(...)
I use c from case 3 in C++.
Should these three behave the same? The results I get from each of them are different.
I wonder whether something changes during the process:
(case 2) when the params are saved and loaded, or when train_fn is called with 'num_timesteps=0'
(case 3) when the function is wrapped with xla_computation.
-------------
> https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/mjx/tutorial.ipynb
Meanwhile, in the quadruped example, what is the purpose of editing the state value (qvel)? I understand that qvel is a sensed (measured) value. How can a random-noise kick applied to the floating-base state (qvel[:2]) ensure that the legs kick?
I imagined that if random noise is added to the floating-base state and pipeline_step runs, the floating base itself would continuously jump in the randomly perturbed direction, rather than the legs.
... def step(self, state: State, action: jax.Array) -> State: # pytype: disable=signature-mismatch rng, cmd_rng, kick_noise_2 = jax.random.split(state.info['rng'], 3)
# kick
push_interval = 10
kick_theta = jax.random.uniform(kick_noise_2, maxval=2 * jp.pi)
kick = jp.array([jp.cos(kick_theta), jp.sin(kick_theta)])
kick *= jp.mod(state.info['step'], push_interval) == 0
qvel = state.pipeline_state.qvel # pytype: disable=attribute-error
qvel = qvel.at[:2].set(kick * self._kick_vel + qvel[:2])
state = state.tree_replace({'pipeline_state.qvel': qvel})
# physics step
motor_targets = self._default_pose + action * self._action_scale
motor_targets = jp.clip(motor_targets, self.lowers, self.uppers)
pipeline_state = self.pipeline_step(state.pipeline_state, motor_targets)
x, xd = pipeline_state.x, pipeline_state.xd
...
Thank you.
Discussed in https://github.com/google/brax/discussions/403