baal-org / baal

Bayesian active learning library for research and industrial use cases.
https://baal.readthedocs.io
Apache License 2.0

Argument in `test_on_dataset` and `train_and_test_on_datasets` functions to write "val_" metrics instead of "test_" #266


arthur-thuy commented 1 year ago

**Is your feature request related to a problem? Please describe.**
The `MetricMixin` class only creates "train" and "test" metrics in the `add_metric` method. This works fine when using only a training and a test set.

However, when a validation set is also used, as in the snippets below, this becomes a problem.

```python
for al_step in range(N_ALSTEP):
    _ = wrapper.train_on_dataset(
        train_dataset, optimizer, BATCH_SIZE, use_cuda=use_cuda
    )
    _ = wrapper.test_on_dataset(val_dataset, BATCH_SIZE)   # validation pass, logged under "test"
    _ = wrapper.test_on_dataset(test_dataset, BATCH_SIZE)  # test pass, also logged under "test"
    metrics = wrapper.get_metrics()
    # Label the next most uncertain items.
    if not active_loop.step():
        # We're done!
        break
```

and, equivalently, with `train_and_test_on_datasets`:

```python
for al_step in range(N_ALSTEP):
    _ = wrapper.train_and_test_on_datasets(
        train_dataset, val_dataset, optimizer, BATCH_SIZE, use_cuda=use_cuda
    )
    _ = wrapper.test_on_dataset(test_dataset, BATCH_SIZE)  # overwrites the validation metrics
    metrics = wrapper.get_metrics()
    # Label the next most uncertain items.
    if not active_loop.step():
        # We're done!
        break
```

Here, the true validation metrics are recorded under "test" and are later overwritten when the true test metrics are recorded under the same "test" prefix.
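For illustration, assuming `get_metrics` returns flat `test_*`-prefixed keys as in the workaround snippet below, the collision looks like this (values are placeholders):

```python
# Illustration only: key names follow the snippets in this issue.
_ = wrapper.test_on_dataset(val_dataset, BATCH_SIZE)
print(wrapper.get_metrics())   # {'test_loss': <val loss>, 'test_accuracy': <val acc>, ...}

_ = wrapper.test_on_dataset(test_dataset, BATCH_SIZE)
print(wrapper.get_metrics())   # {'test_loss': <test loss>, ...} -- the validation values are gone
```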

**Describe the solution you'd like**
It would be nice if the `test_on_dataset` and `train_and_test_on_datasets` functions had an argument to specify which metric prefix is written ("val" or "test").

**Describe alternatives you've considered**
A simple but cumbersome workaround is to build a dict and copy the "test" metrics that actually correspond to validation metrics into it under "val" keys, as follows:

```python
trainval_hist = wrapper.train_and_test_on_datasets(...)
trainval_last = trainval_hist[-1]  # NOTE: take the log of the last epoch
metrics[len(active_set)] = {
    "train_loss": trainval_last["train_loss"],
    "train_accuracy": trainval_last["train_accuracy"],
    "dataset_size": len(active_set),
    "epochs_trained": len(trainval_hist),
    "val_loss": trainval_last["test_loss"],          # actually validation metrics,
    "val_accuracy": trainval_last["test_accuracy"],  # relabelled from "test_*" to "val_*"
}
```
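The held-out test metrics then still have to come from a separate pass and be merged in by hand. A sketch, assuming `get_metrics` exposes the same `test_*` keys as the history entries:

```python
# Sketch: merge the true test metrics into the same record.
# Assumes wrapper.get_metrics() returns a flat dict with "test_*" keys.
_ = wrapper.test_on_dataset(test_dataset, BATCH_SIZE)
test_metrics = wrapper.get_metrics()
metrics[len(active_set)]["test_loss"] = test_metrics["test_loss"]
metrics[len(active_set)]["test_accuracy"] = test_metrics["test_accuracy"]
```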

**Additional context**
/

Dref360 commented 1 year ago

That makes sense! Something like

`wrapper.train_and_test_on_datasets(eval_set='val')`?

For backward compatibility, we would still keep `test` as the default.
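A minimal sketch of the intended usage (`eval_set` is only the name floated above, not merged API):

```python
# Proposed usage sketch -- eval_set is hypothetical at this point.
_ = wrapper.train_and_test_on_datasets(
    train_dataset, val_dataset, optimizer, BATCH_SIZE, use_cuda=use_cuda,
    eval_set="val",  # validation metrics recorded as "val_loss", "val_accuracy", ...
)
_ = wrapper.test_on_dataset(test_dataset, BATCH_SIZE)  # keeps the default "test" prefix
```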

What do you think?

arthur-thuy commented 1 year ago

That would be a good solution in my opinion!