jankrepl / deepdow

Portfolio optimization with deep learning.
https://deepdow.readthedocs.io
Apache License 2.0

What does batch in history.metrics represent? #82

Closed by turmeric-blend 3 years ago

turmeric-blend commented 3 years ago

Hi, I'm looking at the getting_started.ipynb notebook, and in the Evaluation and Visualization section I tried displaying history.metrics. I noticed that for every epoch the batch column only contains the values array([0, 1, 2, 3]). This code confirms it:

import pandas as pd

# history is the training history produced earlier in the notebook
pd.unique(history.metrics['batch'])

May I ask what this batch column represents?

At first I thought it counted the forward passes (one per batch of batch_size samples) within each epoch. However, the synthetic data generated in that notebook has 750 training samples and a training batch_size of 32, which would give about 23 forward passes per epoch (750 / 32 ≈ 23.4). In that case the batch column would be array([0, 1, 2, 3, ... , 21, 22]) for each epoch.

Thanks.

jankrepl commented 3 years ago

Great question!

The history.metrics table only stores statistics coming from dataloaders that are provided via the val_dataloaders parameter:

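import torch
from deepdow.callbacks import EarlyStoppingCallback
from deepdow.experiments import Run

# network, loss, dataloader_train and dataloader_test are defined earlier in the notebook
run = Run(network,
          loss,
          dataloader_train,
          val_dataloaders={'test': dataloader_test},  # <- only these dataloaders feed history.metrics
          optimizer=torch.optim.Adam(network.parameters(), amsgrad=True),
          callbacks=[EarlyStoppingCallback(metric_name='loss',
                                           dataloader_name='test',
                                           patience=15)])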

In the getting_started example this is dataloader_test. Since the test set has only ~120 samples, it yields just 4 batches, which is why the batch column only contains the values 0 to 3.
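The batch counts can be checked with a quick calculation. This is a minimal sketch, assuming the test dataloader also uses a batch size of 32 (the thread only states the training batch_size) and that a loader yields ceil(n_samples / batch_size) batches per epoch, or floor(n_samples / batch_size) when the last incomplete batch is dropped (as with PyTorch's drop_last=True). The helper n_batches is hypothetical.

```python
import math


def n_batches(n_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Hypothetical helper: number of batches a loader yields per epoch."""
    if drop_last:
        # The final incomplete batch is discarded.
        return n_samples // batch_size
    # The final incomplete batch is kept.
    return math.ceil(n_samples / batch_size)


# Test set from the thread (~120 samples); batch size 32 is an assumption.
test_batches = n_batches(120, 32)
# Training set from the thread: 750 samples, batch_size=32.
train_batches_kept = n_batches(750, 32)
train_batches_dropped = n_batches(750, 32, drop_last=True)

print(test_batches, train_batches_kept, train_batches_dropped)  # 4 24 23
```

Under these assumptions the four indices 0 to 3 match the test loader, whereas the training loader would produce 23 or 24 batches depending on whether the last partial batch is kept.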

The logic that decides what goes into history.metrics lives in ValidationCallback; see its on_epoch_end implementation: https://github.com/jankrepl/deepdow/blob/f123a8212ab3b6a2ff8514dfd8b936bc47a8806d/deepdow/callbacks.py#L691
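To illustrate why the batch index restarts at 0 every epoch and only covers validation batches, here is a simplified, hypothetical sketch of the kind of loop such a callback could run at the end of each epoch. It is not deepdow's actual ValidationCallback code: the function name log_validation_metrics, the row fields and the mean-based toy metric are illustrative assumptions, and batches are plain lists instead of tensors.

```python
from typing import Callable, Dict, List, Sequence

# Hypothetical stand-ins: a "batch" is a list of numbers and a metric maps
# a batch to a float. In deepdow these would be tensors and loss objects.
Batch = Sequence[float]
Metric = Callable[[Batch], float]


def log_validation_metrics(
    epoch: int,
    val_dataloaders: Dict[str, List[Batch]],
    metrics: Dict[str, Metric],
) -> List[dict]:
    """Return one row per (dataloader, batch, metric) for this epoch."""
    rows = []
    for dl_name, dataloader in val_dataloaders.items():
        # enumerate restarts at 0 for every dataloader and every epoch,
        # so "batch" only ranges over this dataloader's batches.
        for batch_ix, batch in enumerate(dataloader):
            for metric_name, metric in metrics.items():
                rows.append({
                    "epoch": epoch,
                    "dataloader": dl_name,
                    "batch": batch_ix,
                    "metric": metric_name,
                    "value": metric(batch),
                })
    return rows


# Toy loader with 4 batches, standing in for dataloader_test.
test_loader = [[0.1, 0.2], [0.3], [0.0, 0.4], [0.5]]
metrics = {"mean": lambda b: sum(b) / len(b)}

history_rows = []
for epoch in range(3):
    # The training loader is never passed here, so its batches are not logged.
    history_rows += log_validation_metrics(epoch, {"test": test_loader}, metrics)

print(sorted({row["batch"] for row in history_rows}))  # [0, 1, 2, 3]
```

Because the training batches never go through this loop, the batch column can only show the indices of the validation loader's batches, repeated once per epoch.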

Note that, if you want, you can also include the training dataloader in val_dataloaders to get detailed statistics on the training set at the end of each epoch.
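Once the training dataloader is also passed (for example val_dataloaders={'train': dataloader_train, 'test': dataloader_test}), history.metrics can be split by dataloader to see how many batches each one contributes per epoch. This is a minimal pandas sketch: the toy DataFrame only mimics the table, assumes it has epoch, dataloader and batch columns, and assumes 24 training batches and 4 test batches per epoch. The values are made up for illustration.

```python
import pandas as pd

# Toy stand-in for history.metrics; the column names "epoch", "dataloader"
# and "batch" are assumed, and the values are made up for illustration.
batch_counts = {"train": 24, "test": 4}
records = [
    {"epoch": epoch, "dataloader": name, "batch": batch, "value": 0.0}
    for epoch in range(2)
    for name, count in batch_counts.items()
    for batch in range(count)
]
metrics = pd.DataFrame(records)

# Number of distinct batch indices per epoch and dataloader.
batches_per_epoch = metrics.groupby(["epoch", "dataloader"])["batch"].nunique()
print(batches_per_epoch)

# The check from the question, restricted to a single dataloader.
print(pd.unique(metrics.loc[metrics["dataloader"] == "test", "batch"]))  # [0 1 2 3]
```

Filtering by dataloader before calling pd.unique avoids mixing the training and test batch indices in one column.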

turmeric-blend commented 3 years ago

Ah, of course! Thanks :)