TuragaLab / flybody

MuJoCo fruit fly body model and reinforcement learning tasks
Apache License 2.0

Training Checkpoint Loading Error #7

Closed: scott-yj-yang closed this issue 3 months ago.

scott-yj-yang commented 3 months ago

Hi Team,

When I tried to load a checkpoint to continue training by setting DMPOConfig.checkpoint_to_load = {run_name}/checkpoints/dmpo_learner/ckpt-4.index, I got the following AssertionError inside create_dual_variables_once:

ray.exceptions.RayTaskError(AssertionError): ray::Learner.run() (pid=3441185, ip=10.244.10.88, actor_id=155580693110898c9f9447d701000000, repr=<flybody.agents.ray_distributed_dmpo.Learner object at 0x7f0c90c71240>)
  File "/root/vast/scott-yang/flybody/flybody/agents/ray_distributed_dmpo.py", line 208, in run
    self.step()
  File "/root/vast/scott-yang/flybody/flybody/agents/learning_dmpo.py", line 324, in step
    fetches = self._step()
  File "/root/vast/scott-yang/flybody/flybody/agents/ray_distributed_dmpo.py", line 201, in _step
    return DistributionalMPOLearner._step(self)
  File "/root/anaconda3/envs/flybody/lib/python3.10/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/root/anaconda3/envs/flybody/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py", line 1147, in autograph_handler
    raise e.ag_error_metadata.to_exception(e)
AssertionError: in user code:

    File "/root/vast/scott-yang/flybody/flybody/agents/learning_dmpo.py", line 274, in _step  *
        policy_loss, policy_stats = self._policy_loss_module(
    File "/root/anaconda3/envs/flybody/lib/python3.10/site-packages/sonnet/src/utils.py", line 85, in _decorate_unbound_method  *
        return decorator_fn(bound_method, self, args, kwargs)
    File "/root/anaconda3/envs/flybody/lib/python3.10/site-packages/sonnet/src/base.py", line 262, in wrap_with_name_scope  *
        return method(*args, **kwargs)
    File "/root/vast/scott-yang/flybody/flybody/agents/losses_mpo.py", line 216, in __call__  *
        self.create_dual_variables_once(dual_variable_shape, scalar_dtype)
    File "/root/anaconda3/envs/flybody/lib/python3.10/site-packages/sonnet/src/utils.py", line 85, in _decorate_unbound_method  *
        return decorator_fn(bound_method, self, args, kwargs)
    File "/root/anaconda3/envs/flybody/lib/python3.10/site-packages/sonnet/src/once.py", line 89, in wrapper  *
        _check_no_output(wrapped(*args, **kwargs))
    File "/root/anaconda3/envs/flybody/lib/python3.10/site-packages/sonnet/src/utils.py", line 85, in _decorate_unbound_method  *
        return decorator_fn(bound_method, self, args, kwargs)
    File "/root/anaconda3/envs/flybody/lib/python3.10/site-packages/sonnet/src/base.py", line 262, in wrap_with_name_scope  *
        return method(*args, **kwargs)
    File "/root/vast/scott-yang/flybody/flybody/agents/losses_mpo.py", line 142, in create_dual_variables_once  *
        self._log_temperature = tf.Variable(

    AssertionError:

This error message from tf.Variable is not particularly helpful, since I cannot find any actual assert statement that tells me what to fix in my code. Training does run fine when I don't load the checkpoint, though.

Looking deeper into the checkpoint loading logic, I don't quite understand the code at this line, where the variable _checkpoint is created and restored but never used.
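The flybody line referred to above isn't reproduced in this thread, so the following is only a generic TensorFlow sketch, not flybody's code. With object-based checkpoints, tf.train.Checkpoint.restore() writes the saved values directly into the variables the Checkpoint object tracks, so a Checkpoint that is created, restored, and then never referenced again may still have done its job. The variable name weights is purely illustrative.

```python
import os
import tempfile

import tensorflow as tf

# Illustrative stand-in for any tracked learner state (not a flybody name).
weights = tf.Variable(1.0)

with tempfile.TemporaryDirectory() as tmp_dir:
    # save() returns the checkpoint prefix, e.g. <tmp_dir>/ckpt-1.
    save_path = tf.train.Checkpoint(weights=weights).save(
        os.path.join(tmp_dir, "ckpt"))

    # Change the value so the restore below has a visible effect.
    weights.assign(42.0)

    # This Checkpoint object is created, restored, and never used again:
    # restore() writes the saved value straight into `weights`.
    status = tf.train.Checkpoint(weights=weights).restore(save_path)
    status.assert_existing_objects_matched()

print(weights.numpy())
```

The restored value lives in weights itself, which is why the temporary Checkpoint object does not need to be kept around.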

Could you help me with this issue? Any clarification and guidance is greatly appreciated!

Scott Yang.

vaxenburg commented 3 months ago

Before we look further, can you try the same but without the .index at the end of the checkpoint name?
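Why dropping .index can help: TensorFlow writes each checkpoint as several files that share one prefix (for example ckpt-4.index plus one or more data shards such as ckpt-4.data-00000-of-00001), and its restore functions expect that prefix rather than any single file. A minimal sketch of a hypothetical helper (to_checkpoint_prefix is not part of flybody or TensorFlow) that normalizes a path before it is assigned to DMPOConfig.checkpoint_to_load:

```python
import re

# Suffixes TensorFlow appends to a checkpoint prefix on disk:
#   <prefix>.index and <prefix>.data-<shard>-of-<num_shards>
_CKPT_FILE_SUFFIX = re.compile(r"\.(index|data-\d{5}-of-\d{5})$")


def to_checkpoint_prefix(path: str) -> str:
    """Hypothetical helper: strip a trailing .index / .data-XXXXX-of-YYYYY.

    Paths that are already prefixes are returned unchanged.
    """
    return _CKPT_FILE_SUFFIX.sub("", path)


print(to_checkpoint_prefix("{run_name}/checkpoints/dmpo_learner/ckpt-4.index"))
```

Matching only the known suffixes at the end of the string leaves an already-correct prefix untouched, so the helper is safe to apply unconditionally.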

scott-yj-yang commented 3 months ago

Thank you so much for your quick reply! That fixed the issue! I can't believe the problem was the checkpoint file name, lol. I guess I'm not very familiar with TF's checkpoint naming system. Anyway, thanks for your help!
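For reference, a small generic TensorFlow sketch of that naming system (the step variable and the "ckpt" prefix are illustrative, not taken from flybody): each save produces a .index file and data shard(s) under a numbered prefix, plus a "checkpoint" state file, and tf.train.latest_checkpoint() returns a prefix, which is the form restore() expects.

```python
import os
import tempfile

import tensorflow as tf

with tempfile.TemporaryDirectory() as ckpt_dir:
    ckpt = tf.train.Checkpoint(step=tf.Variable(0))
    for _ in range(2):
        ckpt.step.assign_add(1)
        ckpt.save(os.path.join(ckpt_dir, "ckpt"))  # writes ckpt-1, then ckpt-2

    # Each save leaves a .index file and .data shard(s) under one prefix,
    # plus a "checkpoint" state file recording the most recent prefix.
    files = sorted(os.listdir(ckpt_dir))
    print(files)

    # latest_checkpoint() returns the prefix (no .index suffix).
    latest = tf.train.latest_checkpoint(ckpt_dir)
    print(os.path.basename(latest))
```

Passing the value returned by latest_checkpoint() (or a prefix such as ckpt-4) avoids pointing a restore call at one of the individual files.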