takuseno / d3rlpy

An offline deep reinforcement learning library
https://takuseno.github.io/d3rlpy
MIT License

[Question] Tracking validation loss during training #331

Closed spencerJ777 closed 1 year ago

spencerJ777 commented 1 year ago

Hi, is there a way to track the loss in the validation set during training? Any suggestion would be much appreciated.

joshuaspear commented 1 year ago

Hey - it's the same as tracking the training metrics using the evaluators from https://github.com/takuseno/d3rlpy/blob/master/d3rlpy/metrics/evaluators.py, but when you initialise the evaluator, provide the validation episodes, e.g.:

# Track validation TD error
val_td_scorer = d3rlpy.metrics.TDErrorEvaluator(episodes=val_episodes)
# Track training TD error
train_td_scorer = d3rlpy.metrics.TDErrorEvaluator()

cql.fit(
    dataset,
    n_steps=1000000,
    n_steps_per_epoch=10000,
    evaluators={"val_td_scorer": val_td_scorer, "train_td_scorer":train_td_scorer},
)
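
For completeness, a minimal sketch of how val_episodes could be obtained before training, assuming the dataset exposes an .episodes list and using scikit-learn's train_test_split (neither of which is shown in the thread):

from sklearn.model_selection import train_test_split

# hold out 20% of the episodes for validation; these are what the
# val_td_scorer above receives via the episodes argument
train_episodes, val_episodes = train_test_split(
    dataset.episodes, test_size=0.2, random_state=0
)

Ideally the model would then also be fit only on the training episodes rather than the full dataset.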

Let me know if that doesn't make sense :)

spencerJ777 commented 1 year ago

Hi, thanks for the response. I implemented NFQ and tried to plot the loss against the TD error on the training dataset, but got completely different results. I'm just curious why this is the case, since the NFQ loss and the TD error are essentially the same equation.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from d3rlpy.algos import NFQConfig
from d3rlpy.metrics import TDErrorEvaluator

algo = NFQConfig().create(0)

algo.fit(
    train, 
    n_steps=10000,
    n_steps_per_epoch=100,
    show_progress=False,
    save_interval=np.inf,
    evaluators={
        "train_td_scorer": TDErrorEvaluator(),
    },
)

# the log CSVs are written without a header row; column 1 is the step, column 2 the value
df = pd.read_csv('d3rlpy_logs/NFQ_20230831204920/loss.csv', header=None)
df2 = pd.read_csv('d3rlpy_logs/NFQ_20230831204920/train_td_scorer.csv', header=None)

# plot each metric against the training step
plt.plot(df[1], df[2])
plt.plot(df2[1], df2[2])
plt.legend(['loss', 'td_error'])

[image: output (plot of NFQ loss vs. TD error)]

joshuaspear commented 1 year ago

Hey - can I confirm how big your dataset is, please? Also, are you randomising the transitions?

My thinking is: the TD error blowing up actually indicates that your model is overfitting. You're right that the loss and the TD error are the same calculation; however, the loss is calculated on the batch whereas the TD error is calculated on the entire dataset. If your replay ratio (the number of times you train on the same transition) is high, i.e. your dataset is << 10000 steps, then I would say that's the most likely explanation.
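
To make that concrete, a rough back-of-the-envelope calculation of the replay ratio (the batch size of 32 is an assumption here, not a documented default):

# replay ratio: how many times each transition is (re)used during training
n_steps = 10_000        # gradient steps passed to fit()
batch_size = 32         # assumed batch size; check your algorithm config
n_transitions = 1_000   # size of the offline dataset (example value)

replay_ratio = n_steps * batch_size / n_transitions
print(replay_ratio)  # 320.0 -> each transition is sampled ~320 times on average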

If you're not randomising transitions - that would explain why the TD error between sequential batches is small.

DQN might be better suited to your problem, as having a target Q-network with an update interval > 1 should prevent the Q-network overfitting - I think...

I might be wrong! Let me know what you find though - interested to understand what's going on :)

Edited as I realised you'd already answered a load of the questions I asked, and I'd suggested a rubbish idea!

takuseno commented 1 year ago

I spent some time debugging this and found out that the Huber loss makes the difference. In DQN, we compute a Huber loss, which uses the absolute error instead of the squared error for values above 1.0. On the other hand, TDErrorEvaluator calculates squared errors for all transitions. To confirm this, I internally modified DQN to use the squared error and saw mostly the same extreme values in both the training loss and the validation loss.
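
To illustrate the point, a small standalone sketch (plain NumPy, not d3rlpy code) of how the two quantities diverge once the TD residuals grow beyond the Huber threshold; the 0.5-scaled quadratic / linear form below is the textbook Huber definition, and the exact constants inside d3rlpy may differ:

import numpy as np

def squared_td_error(residuals):
    # what a squared-error metric like TDErrorEvaluator reports
    return np.mean(residuals ** 2)

def huber_loss(residuals, delta=1.0):
    # quadratic below delta, linear above: large residuals are not squared
    abs_r = np.abs(residuals)
    quadratic = 0.5 * abs_r ** 2
    linear = delta * (abs_r - 0.5 * delta)
    return np.mean(np.where(abs_r <= delta, quadratic, linear))

small = np.array([0.1, 0.2, 0.3])     # below 1.0: both stay small
large = np.array([10.0, 20.0, 30.0])  # above 1.0: the squared error explodes

print(squared_td_error(small), huber_loss(small))  # ~0.047 vs ~0.023
print(squared_td_error(large), huber_loss(large))  # ~466.7 vs 19.5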

takuseno commented 1 year ago

If you want, we can add a compute_as_abs option to TDErrorEvaluator to record TD errors as absolute errors.

spencerJ777 commented 1 year ago

@joshuaspear Thanks for your explanation. My dataset is indeed very small (less than 1000 steps, actually). It's recorded from different subjects, so the transitions are quite noisy too. I only have an offline dataset, so I tried plotting the results using DiscreteBCQ and it's not exploding like NFQ.

[image: output (DiscreteBCQ plot)]

spencerJ777 commented 1 year ago

@takuseno Thanks for your explanation too; now it makes sense. I can easily modify TDErrorEvaluator to use absolute error myself, so you don't have to make the change.
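
For anyone landing here later, a rough sketch of what such a modification could look like as a standalone evaluator. It assumes the v2-style evaluator interface (a callable taking the algorithm and the dataset), the algorithm's predict / predict_value methods and gamma attribute, and episodes exposing observations / actions / rewards arrays; the batching, reward scaling, terminal masking and index conventions of the built-in TDErrorEvaluator are simplified away here and should be checked against it:

import numpy as np

class AbsTDErrorEvaluator:
    """Reports the mean absolute TD error instead of the squared one."""

    def __init__(self, episodes=None):
        self._episodes = episodes

    def __call__(self, algo, dataset):
        episodes = self._episodes if self._episodes else dataset.episodes
        errors = []
        for episode in episodes:
            # assumes rewards[i] is the reward for taking actions[i] at observations[i]
            obs = episode.observations[:-1]
            next_obs = episode.observations[1:]
            actions = episode.actions[:-1]
            rewards = np.asarray(episode.rewards[:-1]).reshape(-1)
            # Q(s, a) for the logged actions
            values = algo.predict_value(obs, actions)
            # Q(s', a') for the greedy next actions
            next_actions = algo.predict(next_obs)
            next_values = algo.predict_value(next_obs, next_actions)
            # absolute, rather than squared, TD residuals
            y = rewards + algo.gamma * next_values
            errors.extend(np.abs(values - y).tolist())
        return float(np.mean(errors))

It can then be passed to fit() via evaluators={"val_abs_td": AbsTDErrorEvaluator(episodes=val_episodes)} in the same way as the built-in one.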

joshuaspear commented 1 year ago

@spencerJ777 glad you got it sorted :) I wouldn't 100% take what I said, as the difference in loss/TD error was explained by the difference in the Huber/MSE calculation. Purely out of interest, did you try running just standard DQN? Also, when you mentioned your transitions were noisy due to different subjects, did you mean that the within-episode transitions are noisy?

This is purely for my own learning - still learning about how to train offline models! :)

spencerJ777 commented 1 year ago

@joshuaspear This is what the plot looks like when I use standard DQN. Also, I'm just curious: can all the algorithms be used for offline RL? In the README.md I only saw a checkmark under offline RL (for discrete control) for NFQ, DiscreteBCQ, and DiscreteCQL, so I just assumed the other unchecked ones are not supported for purely offline training.

To your question: I probably didn't word it very well. I have N episodes and each of them is from a different subject. For each subject, P(s'|s,a) can be different, and P(s'|s,a) can also differ between the beginning and the end of an episode. This is the "noise" I'm trying to describe; I don't know if it makes sense. Mine is a problematic case since it's a small dataset and I can't validate the model on the environment. I'm also training it using leave-one-episode-out cross-validation to see if it can get good results when training on N-1 episodes and validating on the remaining one, so the TD error might blow up sometimes.

[image: output (DQN plot)]
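
As a side note on the leave-one-episode-out setup described above, a rough sketch of the outer loop; build_buffer is a hypothetical helper standing in for however the selected episodes get turned back into a d3rlpy dataset/replay buffer, and the algorithm and evaluator names simply mirror the earlier snippets:

from d3rlpy.algos import NFQConfig
from d3rlpy.metrics import TDErrorEvaluator

episodes = dataset.episodes  # assuming the dataset exposes its episode list

for i, held_out in enumerate(episodes):
    # train on N-1 episodes, validate on the remaining one
    train_eps = [ep for j, ep in enumerate(episodes) if j != i]

    algo = NFQConfig().create(0)
    algo.fit(
        build_buffer(train_eps),  # hypothetical helper: episodes -> replay buffer
        n_steps=10000,
        n_steps_per_epoch=100,
        evaluators={
            "train_td_scorer": TDErrorEvaluator(),
            "val_td_scorer": TDErrorEvaluator(episodes=[held_out]),
        },
        experiment_name=f"NFQ_loeo_fold_{i}",
    )

The experiment_name keeps each fold's logs in its own directory so the per-fold curves can be compared afterwards.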

takuseno commented 1 year ago

Offline training is supported by all algorithms in terms of the interface, but only the algorithms marked with offline RL support can learn a policy properly. The other algorithms will struggle with the overestimation issue.

joshuaspear commented 1 year ago

@spencerJ777 thanks for sharing the DQN plots - I was just interested in what the difference was. If this is a project you can share and you’d like someone to bounce ideas off every now and then - I’d be happy to share my email.

Re your noise explanation - no, that does make complete sense. I'm not familiar with the cross-validation approach of leaving just one sample out - I generally have more even splits and a smaller number of splits, i.e. 5-fold. But with respect to the loss you're optimising against - you're better off using OPE - I would recommend reading https://arxiv.org/abs/2005.01643. I'm also developing an OPE library with an API for d3rlpy - https://github.com/joshuaspear/offline_rl_ope. It currently only supports d3rlpy=1.x.x - I'm releasing an update for the latest d3rlpy version next week.