tensortrade-org / tensortrade

An open source reinforcement learning framework for training, evaluating, and deploying robust trading agents.
https://discord.gg/ZZ7BGWh
Apache License 2.0

"Size mismatch" error when restoring checkpoint after Ray tuning #437

Closed fridary closed 2 years ago

fridary commented 2 years ago

I pasted the code from "Train and Evaluate using Ray" https://www.tensortrade.org/en/latest/examples/train_and_evaluate_using_ray.html, and tuning works well. After that I added the code from "Using Ray with TensorTrade" https://www.tensortrade.org/en/latest/tutorials/ray.html to restore the checkpoint:

import ray.rllib.agents.ppo as ppo

# Get checkpoint
checkpoints = analysis.get_trial_checkpoints_paths(
    trial=analysis.get_best_trial("episode_reward_mean"),
    metric="episode_reward_mean",
    mode='max'
)
checkpoint_path = checkpoints[0][0]

# Restore agent
agent = ppo.PPOTrainer(
    env="TradingEnv",
    config={
        "env_config": {
            "window_size": 25
        },
        "framework": "torch",
        "log_level": "DEBUG",
        "ignore_worker_failures": True,
        "num_workers": 1,
        "num_gpus": 0,
        "clip_rewards": True,
        "lr": 8e-6,
        "lr_schedule": [
            [0, 1e-1],
            [int(1e2), 1e-2],
            [int(1e3), 1e-3],
            [int(1e4), 1e-4],
            [int(1e5), 1e-5],
            [int(1e6), 1e-6],
            [int(1e7), 1e-7]
        ],
        "gamma": 0,
        "observation_filter": "MeanStdFilter",
        "lambda": 0.72,
        "vf_loss_coeff": 0.5,
        "entropy_coeff": 0.01
    }
)
agent.restore(checkpoint_path)

This gives me the following error, and I have not been able to find the same issue anywhere on the Internet:

Traceback (most recent call last):
  File "test.py", line 195, in <module>
    agent.restore(checkpoint_path)
  File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/tune/trainable.py", line 467, in restore
    self.load_checkpoint(checkpoint_path)
  File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 1823, in load_checkpoint
    self.__setstate__(extra_data)
  File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 2443, in __setstate__
    self.workers.local_worker().restore(state["worker"])
  File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1346, in restore
    self.policy_map[pid].set_state(state)
  File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 722, in set_state
    super().set_state(state)
  File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/policy.py", line 638, in set_state
    self.set_weights(state["weights"])
  File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 675, in set_weights
    self.model.load_state_dict(weights)
  File "/Users/fridary/miniforge3/envs/rl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1482, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ComplexInputNetwork:
    size mismatch for post_fc_stack._value_branch._model.0.weight: copying a param with shape torch.Size([1, 32]) from checkpoint, the shape in current model is torch.Size([1, 256]).
    size mismatch for logits_layer._model.0.weight: copying a param with shape torch.Size([2, 32]) from checkpoint, the shape in current model is torch.Size([2, 256]).
    size mismatch for value_layer._model.0.weight: copying a param with shape torch.Size([1, 32]) from checkpoint, the shape in current model is torch.Size([1, 256]).

Any ideas what's wrong? If I change "framework" to "tf", there are no errors.
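The traceback says the checkpoint's final layers are 32 units wide while the freshly built model expects 256, which is exactly what `torch.nn.Module.load_state_dict` raises when the saved and rebuilt architectures differ. A standalone PyTorch illustration of the same failure (hypothetical layer sizes, not TensorTrade/RLlib code):

```python
import torch.nn as nn

# The checkpoint was written by a network whose hidden layer is 32 wide.
saved = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
state = saved.state_dict()

# At restore time, a different config rebuilds the network 256 wide.
rebuilt = nn.Sequential(nn.Linear(10, 256), nn.ReLU(), nn.Linear(256, 2))

try:
    rebuilt.load_state_dict(state)
    error = None
except RuntimeError as exc:  # load_state_dict raises on shape mismatch
    error = str(exc)

print(error)  # message contains "size mismatch for 0.weight"

# Rebuilding with the architecture that produced the checkpoint loads cleanly.
matching = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
matching.load_state_dict(state)
```

So the likely cause is that the tuned trial used a different model configuration (e.g. a different `fcnet_hiddens`) than the hand-written config passed to `PPOTrainer` at restore time.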

atomcracker commented 2 years ago

Pass analysis.get_best_config(metric=metric, mode=mode) as the config to your PPO agent, so the restored trainer is built with the exact configuration (including model architecture) that produced the checkpoint.
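A minimal sketch of that suggestion, assuming `analysis` and `checkpoint_path` are the objects from the tuning code in the question (Ray 1.x API, matching the imports used above):

```python
import ray.rllib.agents.ppo as ppo

# Reuse the exact config of the best trial instead of re-typing
# hyperparameters by hand; this preserves whatever model settings
# the search actually sampled, so the checkpoint weights fit.
best_config = analysis.get_best_config(
    metric="episode_reward_mean",
    mode="max"
)

agent = ppo.PPOTrainer(env="TradingEnv", config=best_config)
agent.restore(checkpoint_path)
```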