takuseno / d3rlpy

An offline deep reinforcement learning library
https://takuseno.github.io/d3rlpy
MIT License

NaN in Predictions while online finetune #211

Closed: lettersfromfelix closed this issue 2 years ago

lettersfromfelix commented 2 years ago

Hi @takuseno, first of all, thanks again for your awesome work! With your help I was able to train my agent in a custom environment, and I have already increased its performance significantly. Now I wanted to fine-tune the agent in an online environment. Unfortunately, this only worked for somewhere between 500 and 1000 steps (not a fixed number; it seems arbitrary) before I got a ValueError because NaN values were predicted. I get the following trace. Any idea where I could look to fix this?

```
Exception has occurred: ValueError
Expected parameter loc (Tensor of shape (1, 4)) of distribution Normal(loc: torch.Size([1, 4]), scale: torch.Size([1, 4])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan, nan]])
  File "/home/user/ws/d3/.venv/lib/python3.10/site-packages/torch/distributions/distribution.py", line 55, in __init__
    raise ValueError(
  File "/home/user/ws/d3/.venv/lib/python3.10/site-packages/torch/distributions/normal.py", line 54, in __init__
    super(Normal, self).__init__(batch_shape, validate_args=validate_args)
  File "/home/user/ws/d3/.venv/lib/python3.10/site-packages/d3rlpy/models/torch/distributions.py", line 99, in __init__
    self._dist = Normal(self._mean, self._std)
  File "/home/user/ws/d3/.venv/lib/python3.10/site-packages/d3rlpy/models/torch/policies.py", line 175, in dist
    return SquashedGaussianDistribution(mu, clipped_logstd.exp())
  File "/home/user/ws/d3/.venv/lib/python3.10/site-packages/d3rlpy/models/torch/policies.py", line 189, in forward
    dist = self.dist(x)
  File "/home/user/ws/d3/.venv/lib/python3.10/site-packages/d3rlpy/models/torch/policies.py", line 245, in best_action
    action = self.forward(x, deterministic=True, with_log_prob=False)
  File "/home/user/ws/d3/.venv/lib/python3.10/site-packages/d3rlpy/algos/torch/ddpg_impl.py", line 195, in _predict_best_action
    return self._policy.best_action(x)
  File "/home/user/ws/d3/.venv/lib/python3.10/site-packages/d3rlpy/algos/torch/base.py", line 58, in predict_best_action
    action = self._predict_best_action(x)
  File "/home/user/ws/d3/.venv/lib/python3.10/site-packages/d3rlpy/torch_utility.py", line 295, in wrapper
    return f(self, *tensors, **kwargs)
  File "/home/user/ws/d3/.venv/lib/python3.10/site-packages/d3rlpy/torch_utility.py", line 305, in wrapper
    return f(self, *args, **kwargs)
  File "/home/user/ws/d3/.venv/lib/python3.10/site-packages/d3rlpy/algos/base.py", line 127, in predict
    return self._impl.predict_best_action(x)
  File "/home/user/ws/d3/.venv/lib/python3.10/site-packages/d3rlpy/online/explorers.py", line 50, in sample
    greedy_actions = algo.predict(x)
  File "/home/user/ws/d3/.venv/lib/python3.10/site-packages/d3rlpy/online/iterators.py", line 212, in train_single_env
    action = explorer.sample(algo, x, total_step)[0]
  File "/home/user/ws/d3/.venv/lib/python3.10/site-packages/d3rlpy/algos/base.py", line 251, in fit_online
    train_single_env(
  File "/home/user/ws/d3/simulation/examples/tune_d3rlpy.py", line 78, in <module>
    cql.fit_online(env, buffer, explorer, n_steps=1000)
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
```

I used the following script to start fine-tuning:

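```python
import d3rlpy

# env, scaler and action_scaler are created earlier (not shown);
# the scalers were fitted on the prerecorded dataset used for offline training
cql = d3rlpy.algos.CQL(use_gpu=False, action_scaler=action_scaler, scaler=scaler)
cql.build_with_env(env)
cql.load_model("model_43596.pt")

buffer = d3rlpy.online.buffers.ReplayBuffer(maxlen=100000, env=env)
explorer = d3rlpy.online.explorers.ConstantEpsilonGreedy(0.1)
cql.fit_online(env, buffer, explorer, n_steps=1000)
```
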
takuseno commented 2 years ago

It looks like you're using scalers, but the internal states of scaler and action_scaler need to be fitted with your dataset. I suggest you set up your model by loading the params.json saved during offline training:

```python
cql = d3rlpy.algos.CQL.from_json("d3rlpy_logs/xxx/params.json")
```
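Why the scaler state matters can be seen in a minimal stdlib sketch of a min-max action scaler. The class name `MinMaxActionScalerSketch` is hypothetical and this is not d3rlpy's implementation; it only illustrates that the mapping between environment actions and the [-1, 1] range the policy works in depends on statistics computed from the dataset, so an unfitted scaler, or one fitted on different data, changes what the loaded model sees.

```python
from typing import List, Optional, Sequence


class MinMaxActionScalerSketch:
    """Hypothetical min-max action scaler (illustration only, not d3rlpy's class).

    Maps each action dimension from [minimum, maximum] to [-1, 1]. The
    statistics exist only after fit() has seen the dataset's actions.
    """

    def __init__(self) -> None:
        self.minimum: Optional[List[float]] = None
        self.maximum: Optional[List[float]] = None

    def fit(self, actions: Sequence[Sequence[float]]) -> None:
        # Per-dimension min and max over every action in the dataset.
        columns = list(zip(*actions))
        self.minimum = [min(col) for col in columns]
        self.maximum = [max(col) for col in columns]

    def _check_fitted(self) -> None:
        if self.minimum is None or self.maximum is None:
            raise RuntimeError("scaler is not fitted; call fit(dataset_actions) first")

    def transform(self, action: Sequence[float]) -> List[float]:
        # Environment scale -> [-1, 1]; assumes maximum > minimum per dimension.
        self._check_fitted()
        return [2.0 * (a - lo) / (hi - lo) - 1.0
                for a, lo, hi in zip(action, self.minimum, self.maximum)]

    def reverse_transform(self, action: Sequence[float]) -> List[float]:
        # [-1, 1] -> environment scale.
        self._check_fitted()
        return [(a + 1.0) / 2.0 * (hi - lo) + lo
                for a, lo, hi in zip(action, self.minimum, self.maximum)]


if __name__ == "__main__":
    scaler = MinMaxActionScalerSketch()
    scaler.fit([[0.0, -2.0], [10.0, 2.0], [5.0, 0.0]])
    print(scaler.transform([5.0, 1.0]))  # [0.0, 0.5]
```

Loading the saved params.json is the library's way of restoring such fitted state together with the rest of the configuration, so online fine-tuning uses the same mapping as offline training.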

Also, if you want to fine-tune the CQL policy, I'd recommend fine-tuning with SAC:

```python
sac = d3rlpy.algos.SAC(scaler=cql.scaler, action_scaler=cql.action_scaler)
sac.copy_policy_from(cql)
sac.copy_q_function_from(cql)
```

Please see more details in the documentation.

lettersfromfelix commented 2 years ago

Thanks for the fast reply! Actually, I had initialized the scalers beforehand with the same prerecorded dataset I used before; I just forgot to include that in the snippet. But even when I remove them, or construct the algorithm from JSON as you suggested, I still get the same error :/ Using SAC instead of CQL for fine-tuning shows the same behavior. (By the way, it seems that both your snippet and the one in the README are missing a sac.build_with_env(env) call before copy_policy_from; at least, I get an AssertionError telling me so.)

lettersfromfelix commented 2 years ago

A few more insights I gathered while debugging:

takuseno commented 2 years ago

I see. If you try training SAC from scratch without finetuning and still see NaN errors, the NaN value definitely comes from your environment. Please check this.
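One generic way to check an environment is to wrap it so that every reset() and step() call validates its inputs and outputs and raises at the first NaN or Inf. The sketch below is plain Python; `FiniteCheckWrapper` is a hypothetical name, not a d3rlpy or Gym class, and it assumes the classic Gym API (step returns a 4-tuple) with flat numeric observations and actions. The optional action bounds also flag actions that are finite but outside the action space.

```python
import math
from typing import Any, Optional, Sequence, Tuple


def _all_finite(values: Sequence[float]) -> bool:
    # True only if every entry is a real number (no NaN, +Inf or -Inf).
    return all(math.isfinite(float(v)) for v in values)


class FiniteCheckWrapper:
    """Hypothetical debugging wrapper (not a d3rlpy or Gym class).

    Assumes the classic Gym API: reset() -> obs and
    step(action) -> (obs, reward, done, info), with flat numeric obs/actions.
    """

    def __init__(self, env: Any, action_low: Optional[float] = None,
                 action_high: Optional[float] = None) -> None:
        self.env = env
        self.action_low = action_low
        self.action_high = action_high

    def __getattr__(self, name: str) -> Any:
        # Forward everything else (e.g. observation_space) to the wrapped env.
        if name == "env":
            raise AttributeError(name)
        return getattr(self.env, name)

    def reset(self) -> Any:
        obs = self.env.reset()
        if not _all_finite(obs):
            raise ValueError(f"non-finite observation from reset(): {obs}")
        return obs

    def step(self, action: Sequence[float]) -> Tuple[Any, float, bool, dict]:
        if not _all_finite(action):
            raise ValueError(f"non-finite action passed to step(): {action}")
        # Optional bounds check: also catches finite but out-of-range actions.
        if self.action_low is not None and min(action) < self.action_low:
            raise ValueError(f"action below {self.action_low}: {action}")
        if self.action_high is not None and max(action) > self.action_high:
            raise ValueError(f"action above {self.action_high}: {action}")
        obs, reward, done, info = self.env.step(action)
        if not _all_finite(obs) or not math.isfinite(float(reward)):
            raise ValueError(f"non-finite step() output: obs={obs}, reward={reward}")
        return obs, reward, done, info
```

With `FiniteCheckWrapper(env, action_low=-1.0, action_high=1.0)` (bounds assumed here; use your environment's real limits), the first bad value raises inside the environment with the offending data in the message, instead of surfacing later as NaN predictions. A real setup would more likely subclass `gym.Wrapper`; this version just avoids the dependency.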

lettersfromfelix commented 2 years ago

Makes 100% sense, but I'm still stuck checking my env for that. Surprisingly, when training with standard OpenAI Gym and SB3, I never run into an issue over millions of steps, but as soon as I use plain d3rlpy SAC online training, it fails with the above error. Do you happen to have any hints on how to check my env for that? Observations, actions, rewards, and terminals are always real (not NaN or Inf). I already removed terminal states and only use the TimeLimit wrapper to end episodes, but even with that and timelimit_aware=True, I still get the error on roughly 25% of the reset steps :(

Edit: Found another thing! When using the NormalNoise explorer instead of ConstantEpsilonGreedy, the problem seems to have gone away (at least so far, at 30,000 steps, whereas before it aborted at 1,200 steps at the latest). Sorry if that's a dumb idea, I'm not really familiar with the explorers, but could it be caused by the fact that clipping to the action_scaler's min and max is only present in the NormalNoise explorer? (That still wouldn't explain why the error occurs only on reset steps.)

Edit2: Also, this only holds if I train SAC online from scratch. As soon as I use the CQL-pretrained model and train on top of it with either CQL or SAC, I get the error again. Feels like I have some really weird bug somewhere.
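The clipping hypothesis can be illustrated with a stdlib sketch of Gaussian-noise exploration: add zero-mean noise to the policy's action, then clip each dimension back into the action bounds. `normal_noise_sample` and its parameters are hypothetical, and whether d3rlpy's NormalNoise explorer clips exactly this way is an assumption here.

```python
import random
from typing import List, Sequence


def normal_noise_sample(greedy_action: Sequence[float], std: float,
                        low: float, high: float,
                        rng: random.Random) -> List[float]:
    """Hypothetical Gaussian-noise exploration step (illustration only).

    Adds independent N(0, std**2) noise to each dimension of the policy's
    action, then clips every dimension back into [low, high].
    """
    noisy = [a + rng.gauss(0.0, std) for a in greedy_action]
    return [min(max(a, low), high) for a in noisy]


if __name__ == "__main__":
    rng = random.Random(0)
    for _ in range(3):
        print(normal_noise_sample([0.9, -0.2], std=0.5, low=-1.0, high=1.0, rng=rng))
```

Because of the final clip, every returned action stays inside [low, high] whatever the noise scale, which is the property an explorer for continuous actions needs.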

takuseno commented 2 years ago

Hmm, it seems that you're using ConstantEpsilonGreedy? It's only usable with discrete action algorithms (e.g. DQN), so please don't use it with continuous control environments. Also, if you use SAC, you don't need to specify an explorer, since the policy is already stochastic. Does the issue go away if you don't use any explorer?
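Why SAC needs no extra explorer follows from how a squashed Gaussian policy samples actions. This is a minimal stdlib sketch, assuming the tanh-squashed Gaussian form suggested by the `SquashedGaussianDistribution` in the trace; `mu` and `log_std` stand in for the policy network's outputs, and the function name is hypothetical. The `deterministic` flag mirrors the `deterministic=True` call in `best_action` from the trace, assuming that corresponds to squashing the mean.

```python
import math
import random
from typing import List, Sequence


def sample_squashed_gaussian(mu: Sequence[float], log_std: Sequence[float],
                             rng: random.Random,
                             deterministic: bool = False) -> List[float]:
    """Draw a = tanh(mu + exp(log_std) * eps), eps ~ N(0, 1), per dimension.

    tanh maps every real number into (-1, 1) (in floating point, very large
    inputs can round to exactly +/-1.0), so the sample is both random and
    inside the normalized action range without any extra explorer.
    """
    if deterministic:
        # Greedy action: squash the mean, no noise.
        return [math.tanh(m) for m in mu]
    return [math.tanh(m + math.exp(s) * rng.gauss(0.0, 1.0))
            for m, s in zip(mu, log_std)]


if __name__ == "__main__":
    rng = random.Random(0)
    print(sample_squashed_gaussian([0.5, -2.0], [-1.0, 0.0], rng))
    print(sample_squashed_gaussian([0.5, -2.0], [-1.0, 0.0], rng, deterministic=True))
```

The randomness already comes from `eps`, and its amount is learned through `log_std`, so adding a second exploration mechanism on top only perturbs a policy that is exploring on its own.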

lettersfromfelix commented 2 years ago

Well, you're right; I should have thought about SAC not needing any explorer. Thanks a lot again! But just so I understand: shouldn't this make no difference in theory? Isn't the explorer just adding noise to SAC's original valid actions, or replacing them with random ones, so it shouldn't result in any NaN? Also: is there any way to show you my gratitude for helping? Do you happen to have a buymeacoffee link or something?

takuseno commented 2 years ago

Glad to hear it works on your end! Using the ConstantEpsilonGreedy explorer in a continuous control environment introduces invalid tensors, since the output of ConstantEpsilonGreedy is one-hot vectors for discrete actions. This can make training fail, and the model then produces NaN values (which come from the model, not from the observations or actions).
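The mismatch can be sketched in plain Python, assuming a continuous action space normalized to [-1, 1] and an explorer that, on its random branch, substitutes a discrete action choice for the policy's continuous vector. The encoding used here (an integer index copied into every dimension) is an assumption for illustration and not d3rlpy's `ConstantEpsilonGreedy` source; `discrete_style_epsilon_greedy` and `out_of_range` are hypothetical names, and `n_actions=4` mirrors the 4-dimensional action in the trace.

```python
import random
from typing import List, Sequence


def discrete_style_epsilon_greedy(greedy_action: Sequence[float],
                                  n_actions: int, epsilon: float,
                                  rng: random.Random) -> List[float]:
    """Hypothetical epsilon-greedy step written for discrete actions.

    With probability epsilon the policy's action is replaced by a random
    discrete choice, encoded here as an integer index in [0, n_actions)
    copied into every dimension (an assumed encoding, for illustration).
    """
    if rng.random() < epsilon:
        index = rng.randrange(n_actions)
        return [float(index)] * len(greedy_action)
    return list(greedy_action)


def out_of_range(action: Sequence[float], low: float = -1.0, high: float = 1.0) -> bool:
    # True if any dimension lies outside the continuous action bounds.
    return any(a < low or a > high for a in action)


if __name__ == "__main__":
    rng = random.Random(0)
    policy_action = [0.1, -0.3, 0.5, 0.0]  # stand-in for a squashed policy output
    for _ in range(5):
        action = discrete_style_epsilon_greedy(policy_action, 4, 0.5, rng)
        print(action, "out of range" if out_of_range(action) else "ok")
```

A sample such as `[3.0, 3.0, 3.0, 3.0]` is finite, so it passes NaN checks on observations and actions, yet in the normalized space assumed here no tanh-squashed policy could output it. Storing such transitions in the replay buffer and training on them is one plausible route to the diverging model and NaN predictions described above.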

Regarding any kind of sponsorship, thank you for your kind offer. Currently, there isn't any for this repository. Instead, it'd be great if you pressed the GitHub star button of this repository :smile:

lettersfromfelix commented 2 years ago

That makes sense, thanks a lot! And of course I pressed the star button :)