takuseno / d3rlpy

An offline deep reinforcement learning library
https://takuseno.github.io/d3rlpy
MIT License

[BUG] Q-learning algorithm results differ from version v1.1.0 to v2.2.0 #347

Closed jdesman1 closed 1 year ago

jdesman1 commented 1 year ago

Describe the bug
I am observing significant differences between version 1.1.0 and version 2.2.0 in all aspects of Q-learning model performance. The only before-and-after results I have are for DiscreteBCQ and DiscreteCQL. The loss, TD error, value scale, etc. explode when training these models with the default hyperparameters in version 2+, compared to the much nicer convergence in my experiments from version 1.1.0. I have previously run experiments where I change the seeds in numpy, torch, etc. to rule out a lucky seed, and I have also built v2.2.0 from source to make sure I have the latest bugfix relevant to #346.
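For completeness, this is roughly how I control the seeds across runs (a minimal sketch; the set_global_seeds helper is just my own wrapper, and I seed the underlying libraries directly rather than relying on any d3rlpy-specific helper):

import random

import numpy as np
import torch

def set_global_seeds(seed: int) -> None:
    # Seed every RNG the training pipeline touches so that run-to-run
    # variation comes from the library version, not the seed.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_global_seeds(313)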

Have there been changes in the optimization scheme or default hyperparameters?

To Reproduce
I am using proprietary datasets that I cannot share publicly. However, the datasets have not changed between the two sets of experiments, and downgrading to version 1.1.0 fixes the convergence issues.

Expected behavior
Identical or similar training behavior across versions with the same dataset.

Additional context
I am happy to provide additional details.

takuseno commented 1 year ago

@jdesman1 Thank you for reporting this! I've started an investigation. Depending on what I find, I'll release an emergency patch to fix the issue.

takuseno commented 1 year ago

I'm using this reproduction script to benchmark DiscreteCQL with Atari 2600 datasets. However, v2.2.0 achieves nearly the same result as the v1.1.1 result recorded here. This suggests that the algorithm implementation has not changed. If you can provide more information, it would help the investigation.

jdesman1 commented 1 year ago

@takuseno have the default arguments to these algorithms changed?

takuseno commented 1 year ago

Hmm, they haven't changed. Looking at the loss metrics, v2.2.0 also shows the same trends as v1.1.0. Just to clarify, by the default parameters do you mean the parameters used in this script? https://github.com/takuseno/d3rlpy/blob/f9bde319d7aeb3f2f5cba983833ed441f19fe0c6/reproductions/offline/discrete_cql.py#L24

jdesman1 commented 1 year ago

@takuseno I actually meant the defaults you get when instantiating via the new [AlgorithmName]Config() syntax, compared to the old [Algorithm]() syntax. After migrating to v2.2.0, I see exploding loss, TD error, etc. In fact, the value scale even explodes beyond the maximum possible reward.
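Concretely, I mean the migration from the constructor-based syntax to the config-based one. A sketch of what my code change looks like (the learning_rate and batch_size values are just illustrative overrides, not the library defaults):

import d3rlpy

# v1.x syntax (hyperparameters passed directly to the algorithm):
#   bcq = d3rlpy.algos.DiscreteBCQ(learning_rate=6.25e-5, batch_size=32)

# v2.x syntax: hyperparameters live on a Config object, and create()
# builds the algorithm instance. My question is whether the defaults
# behind this config match the v1 constructor defaults.
bcq = d3rlpy.algos.DiscreteBCQConfig(
    learning_rate=6.25e-5,
    batch_size=32,
).create()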

takuseno commented 1 year ago

I did a quick benchmark to compare loss values.

v2.2.0

import d3rlpy

dataset, env = d3rlpy.datasets.get_cartpole()

cql = d3rlpy.algos.DiscreteCQLConfig().create()

cql.fit(
    dataset,
    n_steps=5000,
    n_steps_per_epoch=100,
    evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
)

v1.1.1

import d3rlpy

dataset, env = d3rlpy.datasets.get_cartpole()

cql = d3rlpy.algos.DiscreteCQL()

cql.fit(
    dataset,
    eval_episodes=dataset.episodes,
    n_steps=5000,
    n_steps_per_epoch=100,
    scorers={"environment": d3rlpy.metrics.evaluate_on_environment(env)},
)

This is a plot of the loss metrics (blue = v2.2.0, orange = v1.1.1): [image]

The result indicates that those two versions behave similarly.
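For reference, a comparison like this can be reproduced from the metric CSVs that d3rlpy writes under d3rlpy_logs/ during fit(). A rough sketch, assuming the default layout of one <metric>.csv per metric with epoch, step, value rows (the run directory names below are placeholders):

import pandas as pd
import matplotlib.pyplot as plt

# Placeholder log directories; by default d3rlpy names each run directory
# after the algorithm and a timestamp under d3rlpy_logs/.
columns = ["epoch", "step", "value"]
v2_loss = pd.read_csv("d3rlpy_logs/<v2_run>/loss.csv", header=None, names=columns)
v1_loss = pd.read_csv("d3rlpy_logs/<v1_run>/loss.csv", header=None, names=columns)

plt.plot(v2_loss["step"], v2_loss["value"], label="v2.2.0")
plt.plot(v1_loss["step"], v1_loss["value"], label="v1.1.1")
plt.xlabel("step")
plt.ylabel("loss")
plt.legend()
plt.show()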

On a side note, if you use timeouts with your dataset in v2.2.0, its behavior is slightly different from episode_terminals in v1.1.1. That could alter the terminal conditions and lead to the Q-function exploding. https://github.com/takuseno/d3rlpy/blob/f9bde319d7aeb3f2f5cba983833ed441f19fe0c6/d3rlpy/dataset/compat.py#L46
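For anyone translating a v1-style dataset, a rough sketch of the mapping as I read the compat code linked above (an assumption, not an official recipe): in v2, episode ends that are NOT true environment terminals are expressed as timeouts. The dummy arrays below just stand in for real data.

import numpy as np
import d3rlpy

# Dummy data: 100 steps, 4-dim observations, discrete actions.
# terminals marks true environment terminations; episode_terminals marks
# every episode boundary (v1-style), including boundaries that are only cuts.
observations = np.random.random((100, 4)).astype(np.float32)
actions = np.random.randint(2, size=100)
rewards = np.random.random(100).astype(np.float32)
terminals = np.zeros(100, dtype=np.float32)
terminals[49] = 1.0
episode_terminals = np.zeros(100, dtype=np.float32)
episode_terminals[49] = 1.0   # a real terminal
episode_terminals[99] = 1.0   # an episode cut short, i.e. a timeout

# v1.1.1:
#   dataset = d3rlpy.dataset.MDPDataset(
#       observations, actions, rewards, terminals,
#       episode_terminals=episode_terminals)
#
# v2.x: episode ends that are not true terminals become timeouts instead.
timeouts = np.logical_and(episode_terminals, np.logical_not(terminals))
dataset = d3rlpy.dataset.MDPDataset(
    observations, actions, rewards, terminals, timeouts=timeouts
)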

jdesman1 commented 1 year ago

Hm, I don't use the timeouts.

I've been able to reproduce this figure on my end with the cartpole data, which reassures me that my installation is fine. However, with my own datasets I see a different result. For example, with BCQ, the loss previously started around ~2.5 and decreased, whereas it now starts around ~5.5 and increases. I recognize that not being able to share my dataset publicly makes this harder to debug, unfortunately.
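In case it helps, this is roughly how I'm monitoring the value scale and TD error in v2.2.0 (my_dataset stands in for the proprietary data; I've substituted cartpole here only so the snippet runs, and the evaluator names are the v2 metrics API as I understand it):

import d3rlpy

# Placeholder: in practice this is my proprietary MDPDataset.
my_dataset, _ = d3rlpy.datasets.get_cartpole()

bcq = d3rlpy.algos.DiscreteBCQConfig().create()

# TD error and average value estimates on the dataset's episodes make the
# "value scale explodes past the maximum possible reward" symptom visible
# directly in the training logs.
bcq.fit(
    my_dataset,
    n_steps=5000,
    n_steps_per_epoch=100,
    evaluators={
        "td_error": d3rlpy.metrics.TDErrorEvaluator(episodes=my_dataset.episodes),
        "value_scale": d3rlpy.metrics.AverageValueEstimationEvaluator(
            episodes=my_dataset.episodes
        ),
    },
)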

In the meantime, while this is being investigated, and given that the algorithms should produce the same results, is it safe to keep using v1.1.1 instead of v2.2.0?

takuseno commented 1 year ago

Yes, v1.1.1 is still stable. You can keep using it until the issue is resolved.