vocalpy / vak

A neural network framework for researchers studying acoustic communication
https://vak.readthedocs.io
BSD 3-Clause "New" or "Revised" License

BUG: `vak.train_classification` throws KeyError when checking for dataset_config["params"]["target_type"] #779

Closed henricombrink closed 2 weeks ago

henricombrink commented 3 weeks ago

Description: I am currently trying to go through the vak tutorial and have completed the vak prep step successfully. I changed the parameters in the .toml files, and [vak.train.dataset] gets updated in the .toml file after the prep step. However, I get the following error when I try to run vak train gy6or6_train.toml :

vak train gy6or6_train_test.toml 
2024-10-17 15:00:07,142 - vak.cli.train - INFO - vak version: 1.0.2
2024-10-17 15:00:07,142 - vak.cli.train - INFO - Logging results to /home/henri/Downloads/vakex/training_output/results_241017_150007
2024-10-17 15:00:07,142 - vak.train.frame_classification - INFO - Will save results in `results_path`: /home/henri/Downloads/vakex/training_output/results_241017_150007
2024-10-17 15:00:07,142 - vak.train.frame_classification - INFO - Loading dataset from `dataset_path`: /home/henri/Downloads/vakex/training_output/training-vak-frame-classification-dataset-generated-241017_145340
Using dataset config: {'path': PosixPath('/home/henri/Downloads/vakex/training_output/training-vak-frame-classification-dataset-generated-241017_145340'), 'splits_path': None, 'name': None, 'params': {'window_size': 176}}
2024-10-17 15:00:07,144 - vak.train.frame_classification - INFO - Duration of a frame in dataset, in seconds: 0.002
2024-10-17 15:00:07,144 - vak.train.frame_classification - INFO - Using training split from dataset: /home/henri/Downloads/vakex/training_output/training-vak-frame-classification-dataset-generated-241017_145340
2024-10-17 15:00:07,144 - vak.train.frame_classification - INFO - Total duration of training split from dataset (in s): 56.292
2024-10-17 15:00:07,144 - vak.train.frame_classification - INFO - loading labelmap from path: /home/henri/Downloads/vakex/training_output/training-vak-frame-classification-dataset-generated-241017_145340/labelmap.json
2024-10-17 15:00:07,144 - vak.train.frame_classification - INFO - No `frames_standardizer_path` provided, not loading
2024-10-17 15:00:07,145 - vak.train.frame_classification - INFO - Will standardize (normalize) frames
2024-10-17 15:00:07,204 - vak.train.frame_classification - INFO - Duration of TrainDatapipe used for training, in seconds: 56.292
2024-10-17 15:00:07,204 - vak.train.frame_classification - INFO - Will measure error on validation set every 400 steps of training
2024-10-17 15:00:07,204 - vak.train.frame_classification - INFO - Using validation split from dataset:
/home/henri/Downloads/vakex/training_output/training-vak-frame-classification-dataset-generated-241017_145340
2024-10-17 15:00:07,205 - vak.train.frame_classification - INFO - Total duration of validation split from dataset (in s): 22.472
2024-10-17 15:00:07,206 - vak.train.frame_classification - INFO - Duration of InferDatapipe used for evaluation, in seconds: 22.472
2024-10-17 15:00:07,223 - vak.train.frame_classification - INFO - training TweetyNet
Traceback (most recent call last):
  File "/home/henri/anaconda3/bin/vak", line 10, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/henri/anaconda3/lib/python3.12/site-packages/vak/__main__.py", line 49, in main
    cli.cli(command=args.command, config_file=args.configfile)
  File "/home/henri/anaconda3/lib/python3.12/site-packages/vak/cli/cli.py", line 54, in cli
    COMMAND_FUNCTION_MAP[command](toml_path=config_file)
  File "/home/henri/anaconda3/lib/python3.12/site-packages/vak/cli/cli.py", line 10, in train
    train(toml_path=toml_path)
  File "/home/henri/anaconda3/lib/python3.12/site-packages/vak/cli/train.py", line 55, in train
    train_module.train(
  File "/home/henri/anaconda3/lib/python3.12/site-packages/vak/train/train_.py", line 150, in train
    train_frame_classification_model(
  File "/home/henri/anaconda3/lib/python3.12/site-packages/vak/train/frame_classification.py", line 424, in train_frame_classification_model
    if isinstance(dataset_config["params"]["target_type"], list) and all([isinstance(target_type, str) for target_type in dataset_config["params"]["target_type"]]):
                  ~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^
KeyError: 'target_type'

System: Ubuntu 22.04.5 LTS, vak 1.0.2

NickleDave commented 2 weeks ago

Hi @henricombrink I'm sorry I missed the notification for this last week.

Thank you for your detailed report.
I am able to reproduce this locally.

At first glance, it looks to me like I introduced a bug in version 1.0.2, but I haven't quite got to the bottom of it yet.

For now, can you please use version 1.0.1?
You can install it by running `pip install vak==1.0.1` or `conda install vak==1.0.1 -c conda-forge` in your environment.

I tested this and was able to get the tutorial to work with 1.0.1 at least through the train step, and I am pretty sure from other users working through it that the rest should be working as well.

The only changes in version 1.0.2 are related to features I am adding to work with a benchmark dataset, so using 1.0.1 instead of 1.0.2 will not impact how you work with your own data.

I have been working on related changes to vak and vocalpy. But I will figure out today or tomorrow the cause of this bug and I will put fixing it on the top of the to-do list.

Thank you for catching it, I am embarrassed I put out a version where the tutorial wasn't working :flushed: -- seems like I need to set up tests to catch that :thinking:

henricombrink commented 2 weeks ago

Thank you for the reply @NickleDave , much appreciated.

I uninstalled vak==1.0.2 and installed vak==1.0.1. The training script seems to be running now. If I encounter further problems I will report back here, but for now the version change seems to have solved it.

Thank you for the help.

NickleDave commented 2 weeks ago

Great, glad to hear it, I will link back to the issue describing the bug here and keep you updated

Turns out gmail was sending my notifications for this repo to spam for some reason 😠 I will add it to the safelist

NickleDave commented 2 weeks ago

And @henricombrink please just let me know what else I can do to help.
Our software is mainly being used by neuro labs but the goal is for it to be more broadly useful

Looks like you're doing PAM work? (Did some Google stalking, hope that's ok.) I added you to the forum as well--please feel free to introduce yourself there if you have a chance.

NickleDave commented 2 weeks ago

Decided to not make a separate issue for this, instead just reworded the title to remind me what the source of the error is

We throw the KeyError here: https://github.com/vocalpy/vak/blob/7f8754c4b858a687da436348d9cc6bdcc81d78cc/src/vak/train/frame_classification.py#L424

Looks like this is the offending commit that introduced this bug: f40a3d420f5598dba61779b585b46af09d854184

What's happening here is we are setting up to call get_trainer; specifically we need to determine whether there are multiple target types, and if so, we need to more precisely specify the accuracy we are going to monitor for early stopping. But this currently only matters for the BioSoundSegBench dataset (soon to be named CMACBench); for user prep'd datasets, we always use a single target, multi-class frame labels.
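
The failure can be reproduced outside of vak with a plain dictionary. This is a minimal sketch, assuming only the dataset config printed in the log above (the path is replaced with a placeholder); it does not call any vak function.

```python
from pathlib import Path

# Dataset config as printed in the log for a user-prepared dataset.
# The path is a placeholder; note that "params" has no "target_type" key.
dataset_config = {
    "path": Path("/path/to/prepared/dataset"),
    "splits_path": None,
    "name": None,
    "params": {"window_size": 176},
}

missing_key = None
try:
    # Same lookup as the failing line in train/frame_classification.py
    isinstance(dataset_config["params"]["target_type"], list)
except KeyError as err:
    missing_key = err.args[0]

print(f"KeyError: {missing_key!r}")  # KeyError: 'target_type'
```

Any dataset prepared with vak prep produces a config like this one, which is why the tutorial hits the error on the first call to train.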

If I run the unit tests on train.frame_classification, then I do trigger this bug with the very first unit test.
So this is really my fault for (1) not running tests locally before releasing, and (2) not having CI working to catch it either: I need to finish #736.

I think a quick fix is just to insert a check for the key before the logic that decides how many target types there are, like so:

    if "target_type" in dataset_config["params"]:
        if isinstance(dataset_config["params"]["target_type"], list) and all([isinstance(target_type, str) for target_type in dataset_config["params"]["target_type"]]):
            multiple_targets = True
        elif isinstance(dataset_config["params"]["target_type"], str):
            multiple_targets = False
        else:
            raise ValueError(
                f'Invalid value for dataset_config["params"]["target_type"]: {dataset_config["params"]["target_type"]}'
            )
    else:
        multiple_targets = False
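
The same check can be sketched as a standalone function, which makes it easy to exercise each branch. The name `has_multiple_targets` is hypothetical and not part of vak's API, and the second target type name in the usage lines is only illustrative.

```python
def has_multiple_targets(dataset_config: dict) -> bool:
    """Return True if the config declares a list of target types.

    Hypothetical helper, not part of vak's API.
    """
    params = dataset_config["params"]
    # User-prepared datasets have no "target_type" key: single target
    if "target_type" not in params:
        return False
    target_type = params["target_type"]
    if isinstance(target_type, str):
        return False
    if isinstance(target_type, list) and all(
        isinstance(item, str) for item in target_type
    ):
        return True
    raise ValueError(
        f'Invalid value for dataset_config["params"]["target_type"]: {target_type!r}'
    )


# Config from the bug report: key is missing
print(has_multiple_targets({"params": {"window_size": 176}}))  # False
# A single target type given as a string
print(has_multiple_targets({"params": {"target_type": "multi_frame_labels"}}))  # False
# Several target types given as a list (second name is illustrative)
print(
    has_multiple_targets(
        {"params": {"target_type": ["multi_frame_labels", "other_labels"]}}
    )
)  # True
```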

I made this fix in a quick-and-dirty way and all tests pass, so I will go ahead and release a bugfix version to close this issue.

Long term I need to think about how to organize all this, it feels very kludgy. The reason is that we are not really committed to providing the ability to specify different target types, so we don't have a designated way to declare them. E.g., we could have a default target type of "multi_frame_labels" and then more directly infer which value to monitor from there. I will raise a separate issue about that
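
One way the default-target-type idea could look is sketched below. Everything here is an assumption for illustration: the function names, the metric names `"val_acc"` and `"val_<target>_acc"`, and the choice to monitor the first listed target are all hypothetical and do not reflect how vak currently works.

```python
DEFAULT_TARGET_TYPE = "multi_frame_labels"


def resolve_target_types(dataset_params: dict) -> list[str]:
    """Normalize target_type to a list, falling back to the default.

    Hypothetical helper, not part of vak's API.
    """
    target_type = dataset_params.get("target_type", DEFAULT_TARGET_TYPE)
    if isinstance(target_type, str):
        return [target_type]
    if (
        isinstance(target_type, list)
        and len(target_type) > 0
        and all(isinstance(item, str) for item in target_type)
    ):
        return list(target_type)
    raise ValueError(f"Invalid value for target_type: {target_type!r}")


def metric_to_monitor(dataset_params: dict) -> str:
    """Pick the value to monitor for early stopping (metric names are made up)."""
    target_types = resolve_target_types(dataset_params)
    if len(target_types) == 1:
        return "val_acc"
    # With several targets, monitor the first one listed (an assumption)
    return f"val_{target_types[0]}_acc"


print(metric_to_monitor({"window_size": 176}))  # val_acc
```

With a declared default, the training code never has to test whether the key exists; it only has to handle a list of one or more names.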

NickleDave commented 2 weeks ago

@henricombrink I just published version 1.0.3 to pypi that should fix this. A conda-forge package should follow shortly.

Sorry again for releasing with a trivial bug and for not getting the notification, and thank you for reporting the bug!

NickleDave commented 2 weeks ago

@all-contributors please add @henricombrink for bug

allcontributors[bot] commented 2 weeks ago

@NickleDave

I've put up a pull request to add @henricombrink! :tada:

henricombrink commented 2 weeks ago

@NickleDave Thanks for the quick fix, much appreciated. I will let you know if I encounter further problems.

I am working on red squirrel vocalizations. I have used BirdNet to annotate squirrel rattles from a large collection of sound recordings. I am hoping to use VAK to annotate individual syllables within rattle sequences.

NickleDave commented 2 weeks ago

Sounds very cool, can't wait to hear what you all are learning about red squirrel vocalizations when you're ready to share.

Please just let me know whatever else I can do to help.