mlcommons / GaNDLF

A generalizable application framework for segmentation, regression, and classification using PyTorch
https://gandlf.org
Apache License 2.0
150 stars, 78 forks

[BUG] dice_per_label metric fails #883

Closed: VukW closed this issue 1 month ago

VukW commented 1 month ago

Describe the bug

As reported in https://github.com/mlcommons/GaNDLF/pull/868#issuecomment-2148174703, the dice_per_label metric fails in some cases. It appears to return an array of length 3 when 4 classes are expected:

Looping over training data:   0%|          | 0/6255 [02:51<?, ?it/s]
ERROR: Traceback (most recent call last):
  File "/net/tscratch/people/plgmazurekagh/gandlf/gandlf_env/bin/gandlf_run", line 126, in <module>
    main_run(
  File "/net/tscratch/people/plgmazurekagh/gandlf/gandlf_env/lib/python3.10/site-packages/GANDLF/cli/main_run.py", line 92, in main_run
    TrainingManager_split(
  File "/net/tscratch/people/plgmazurekagh/gandlf/gandlf_env/lib/python3.10/site-packages/GANDLF/training_manager.py", line 173, in TrainingManager_split
    training_loop(
  File "/net/tscratch/people/plgmazurekagh/gandlf/gandlf_env/lib/python3.10/site-packages/GANDLF/compute/training_loop.py", line 445, in training_loop
    epoch_train_loss, epoch_train_metric = train_network(
  File "/net/tscratch/people/plgmazurekagh/gandlf/gandlf_env/lib/python3.10/site-packages/GANDLF/compute/training_loop.py", line 171, in train_network
    total_epoch_train_metric[metric] += metric_val
ValueError: operands could not be broadcast together with shapes (4,) (3,) (4,) 
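The error is consistent with NumPy broadcasting rules: an array of shape (3,) cannot be added in place to an accumulator of shape (4,). The following minimal reproduction assumes (this is not confirmed by the issue) that both the accumulator and the metric value are NumPy arrays. The dictionary key and the numbers are illustrative only.

```python
import numpy as np

# Hypothetical stand-ins: an accumulator sized for 4 classes and a metric value
# that, as in the report, has only 3 entries.
total_epoch_train_metric = {"dice_per_label": np.zeros(4)}
metric_val = np.array([0.9, 0.8, 0.7])

try:
    # In-place addition requires metric_val to broadcast to the accumulator's shape (4,).
    total_epoch_train_metric["dice_per_label"] += metric_val
except ValueError as err:
    print("ValueError:", err)
```

The message lists the two operand shapes followed by the output shape. For `+=` the output is the accumulator itself, which is why `(4,)` appears twice.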

To Reproduce

Steps to reproduce the behavior:

  1. Use the BraTS 2021 dataset.
  2. Start training a model with the dice_per_label metric listed in the config.

Additional context

It is not yet clear whether the bug also exists on the main branch or was introduced by the metrics fix in https://github.com/mlcommons/GaNDLF/pull/868.
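One hypothesis consistent with the traceback, though this issue does not confirm it, is that the metric scores only the labels present in a batch, so a batch missing one of the 4 classes yields 3 values. The sketch below is a hypothetical `dice_per_label` helper. It is not GaNDLF's implementation and not the change made in #868. It assumes integer label maps as input and iterates over every class index, so the output length always equals `num_classes`.

```python
import numpy as np


def dice_per_label(pred: np.ndarray, target: np.ndarray, num_classes: int) -> np.ndarray:
    """Hypothetical per-label Dice that always returns num_classes scores.

    pred and target are integer label maps of the same shape.
    """
    scores = np.empty(num_classes, dtype=np.float64)
    for c in range(num_classes):
        p = pred == c
        t = target == c
        denom = p.sum() + t.sum()
        # A class absent from both maps is scored 1.0 (one possible convention),
        # so the output length never depends on which labels appear in a batch.
        scores[c] = 1.0 if denom == 0 else 2.0 * np.logical_and(p, t).sum() / denom
    return scores


pred = np.array([[0, 1], [2, 2]])
target = np.array([[0, 1], [2, 1]])  # class 3 appears in neither map
scores = dice_per_label(pred, target, num_classes=4)
print(scores.shape)  # → (4,)

total = np.zeros(4)
total += scores  # shapes match, so the in-place accumulation succeeds
```

Scoring an absent class as 1.0 is one convention. Returning NaN and averaging with `np.nanmean` is another. Either choice keeps the output shape fixed, which is what an accumulator like the one in `train_network` requires.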

VukW commented 1 month ago

Closed by commit https://github.com/mlcommons/GaNDLF/pull/868/commits/d0d25fbbc91d7f3ae3235a0b3e80ada65d1f8787 in https://github.com/mlcommons/GaNDLF/pull/868.