baal-org / baal

Bayesian active learning library for research and industrial use cases.
https://baal.readthedocs.io
Apache License 2.0

torchmetrics Accuracy() fails if `get_metrics()` is called before `test_on_dataset` #272

Open arthur-thuy opened 10 months ago

arthur-thuy commented 10 months ago

Describe the bug

The torchmetrics Accuracy() class raises RuntimeError: You have to have determined mode. if wrapper.get_metrics() is called before wrapper.test_on_dataset, or if wrapper.test_on_dataset is not called at all.

In contrast, Baal's Accuracy() class handles this by returning 'test_accuracy': nan.

To Reproduce

In this gist, a LeNet-5 model with MC Dropout is trained on the entire MNIST data for 1 epoch (no acquisitions).

The script uses Baal's Accuracy() class by default and adds the torchmetrics Accuracy() class when run with the option --torchmetrics. It evaluates on the test set by default and skips that evaluation when run with the option --no-test.

Running python baal_error_torchmetrics.py --no-test:

/home/abthuy/Documents/PhD research/active-uncertainty/src/baal_error_torchmetrics.py:45: UserWarning: You have chosen to seed training. This will turn on the CUDNN deterministic setting, which can slow down your training considerably!
  set_seed(config.seed)
Use GPU: NVIDIA RTX A5000
labelling 100 observations
[1724909-MainThread] [baal.modelwrapper:train_on_dataset:83] 2023-09-08T08:28:32.092743Z [info     ] Starting training              dataset=100 epoch=1
[1724909-MainThread] [baal.modelwrapper:train_on_dataset:94] 2023-09-08T08:28:33.700075Z [info     ] Training complete              train_loss=2.309847593307495
Elapsed training time: 0:0:1
{'dataset_size': 100,
 'test_accuracy': nan,
 'test_loss': nan,
 'train_accuracy': 0.08999999612569809,
 'train_loss': 2.309847593307495}
Elapsed total time: 0:0:2

Running python baal_error_torchmetrics.py --torchmetrics:

/home/abthuy/Documents/PhD research/active-uncertainty/src/baal_error_torchmetrics.py:45: UserWarning: You have chosen to seed training. This will turn on the CUDNN deterministic setting, which can slow down your training considerably!
  set_seed(config.seed)
Use GPU: NVIDIA RTX A5000
labelling 100 observations
[1724520-MainThread] [baal.modelwrapper:train_on_dataset:83] 2023-09-08T08:26:52.913297Z [info     ] Starting training              dataset=100 epoch=1
[1724520-MainThread] [baal.modelwrapper:train_on_dataset:94] 2023-09-08T08:26:54.735623Z [info     ] Training complete              train_loss=2.309847593307495
Elapsed training time: 0:0:1
[1724520-MainThread] [baal.modelwrapper:test_on_dataset:123] 2023-09-08T08:26:54.736589Z [info     ] Starting evaluating            dataset=10000
[1724520-MainThread] [baal.modelwrapper:test_on_dataset:133] 2023-09-08T08:26:57.792895Z [info     ] Evaluation complete            test_loss=2.2263691425323486
{'dataset_size': 100,
 'test_accuracy': 0.19924747943878174,
 'test_loss': 2.2263691425323486,
 'test_torch_accuracy': 0.19900000095367432,
 'train_accuracy': 0.08999999612569809,
 'train_loss': 2.309847593307495,
 'train_torch_accuracy': 0.09000000357627869}
Elapsed total time: 0:0:6

Running python baal_error_torchmetrics.py --torchmetrics --no-test:

/home/abthuy/Documents/PhD research/active-uncertainty/src/baal_error_torchmetrics.py:45: UserWarning: You have chosen to seed training. This will turn on the CUDNN deterministic setting, which can slow down your training considerably!
  set_seed(config.seed)
Use GPU: NVIDIA RTX A5000
labelling 100 observations
[1724780-MainThread] [baal.modelwrapper:train_on_dataset:83] 2023-09-08T08:27:53.001640Z [info     ] Starting training              dataset=100 epoch=1
[1724780-MainThread] [baal.modelwrapper:train_on_dataset:94] 2023-09-08T08:27:54.803111Z [info     ] Training complete              train_loss=2.309847593307495
Elapsed training time: 0:0:1
/home/abthuy/anaconda3/envs/al_uncertainty/lib/python3.9/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: The ``compute`` method of metric Accuracy was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
  warnings.warn(*args, **kwargs)
Traceback (most recent call last):
  File "/home/abthuy/Documents/PhD research/active-uncertainty/src/baal_error_torchmetrics.py", line 206, in <module>
    main()
  File "/home/abthuy/Documents/PhD research/active-uncertainty/src/baal_error_torchmetrics.py", line 115, in main
    pprint(wrapper.get_metrics())
  File "/home/abthuy/anaconda3/envs/al_uncertainty/lib/python3.9/site-packages/baal/metrics/mixin.py", line 71, in get_metrics
    metrics = {
  File "/home/abthuy/anaconda3/envs/al_uncertainty/lib/python3.9/site-packages/baal/metrics/mixin.py", line 72, in <dictcomp>
    met_name: get_value(met) for met_name, met in self.metrics.items() if filter in met_name
  File "/home/abthuy/anaconda3/envs/al_uncertainty/lib/python3.9/site-packages/baal/metrics/mixin.py", line 66, in get_value
    val = met.compute().detach().cpu().numpy()
  File "/home/abthuy/anaconda3/envs/al_uncertainty/lib/python3.9/site-packages/torchmetrics/metric.py", line 531, in wrapped_func
    value = compute(*args, **kwargs)
  File "/home/abthuy/anaconda3/envs/al_uncertainty/lib/python3.9/site-packages/torchmetrics/classification/accuracy.py", line 266, in compute
    raise RuntimeError("You have to have determined mode.")
RuntimeError: You have to have determined mode.

Expected behavior

The torchmetrics Accuracy() class should also return 'test_torch_accuracy': nan, just like Baal's Accuracy() class.
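One way to get the nan fallback is to guard the compute() call that appears in the traceback. The following is only a rough sketch under assumptions, not Baal's actual implementation: NeverUpdatedMetric, ConstantMetric, and safe_value are hypothetical stand-ins, so the example runs without torchmetrics or Baal installed.

```python
import math


class NeverUpdatedMetric:
    """Hypothetical stand-in for a metric whose compute() fails before any update()."""

    def compute(self):
        raise RuntimeError("You have to have determined mode.")


class ConstantMetric:
    """Hypothetical stand-in for a metric that has already been updated."""

    def __init__(self, value):
        self._value = value

    def compute(self):
        return self._value


def safe_value(metric):
    """Return the metric value, or nan if compute() raises RuntimeError.

    Assumption: treating any RuntimeError as "no data yet" is acceptable here.
    """
    try:
        return float(metric.compute())
    except RuntimeError:
        return math.nan


metrics = {
    "train_torch_accuracy": ConstantMetric(0.09),
    "test_torch_accuracy": NeverUpdatedMetric(),
}
report = {name: safe_value(met) for name, met in metrics.items()}
print(report)
```

Catching RuntimeError this broadly could also hide unrelated failures, so a real fix might prefer to check whether the metric has been updated before calling compute().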

Version:

Additional context /

arthur-thuy commented 10 months ago

I found a way to circumvent this problem in my code. When I want the training metrics before running test_on_dataset, I can just do get_metrics("train") and torchmetrics won't throw an error.
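This works because the dict comprehension shown in the traceback only calls compute() on metrics whose name contains the filter string. A minimal sketch of that filtering, using hypothetical stand-ins (StubMetric and a free-standing get_metrics function) instead of the real Baal wrapper and torchmetrics classes:

```python
class StubMetric:
    """Hypothetical stand-in: compute() fails until a value has been recorded."""

    def __init__(self, value=None):
        self._value = value

    def compute(self):
        if self._value is None:
            raise RuntimeError("You have to have determined mode.")
        return self._value


def get_metrics(metrics, name_filter=""):
    """Compute only the metrics whose name contains name_filter."""
    return {
        name: met.compute()
        for name, met in metrics.items()
        if name_filter in name
    }


metrics = {
    "train_torch_accuracy": StubMetric(0.09),  # updated during training
    "test_torch_accuracy": StubMetric(),       # never updated (no test run)
}

# The "train" filter skips the test metric, so its compute() is never called.
train_only = get_metrics(metrics, "train")
print(train_only)

# Without a filter, the never-updated test metric is computed and raises.
try:
    get_metrics(metrics)
    raised = False
except RuntimeError:
    raised = True
print(raised)
```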

Dref360 commented 10 months ago

Thank you for submitting this issue!

I'll take a deeper look over the weekend. Unfortunately, I don't have a good idea of how to fix this right now.