sgoldenlab / simba

SimBA (Simple Behavioral Analysis), a pipeline and GUI for developing supervised behavioral classifiers
https://simba-uw-tf-dev.readthedocs.io/
GNU General Public License v3.0

`KeyError` when training multiple models from meta files using custom weights #248

Closed: florianduclot closed this issue 1 year ago

florianduclot commented 1 year ago

Describe the bug: A `KeyError` is thrown when training multiple models from meta files using custom weights. The meta file contains:

The following error is thrown:

```
Exception in Tkinter callback
Traceback (most recent call last):
  File "C:\Users\username\miniconda3\envs\simbadev\lib\tkinter\__init__.py", line 1705, in __call__
    return self.func(*args)
  File "C:\Users\username\miniconda3\envs\simbadev\lib\site-packages\simba\SimBA.py", line 363, in <lambda>
    button_train_multimodel = Button(label_trainmachinemodel, text='TRAIN MULTIPLE MODELS (ONE FOR EACH SAVED SETTING)',fg='green',command = lambda: threading.Thread(target=self.train_multiple_models_from_meta(config_path=self.config_path)).start())
  File "C:\Users\username\miniconda3\envs\simbadev\lib\site-packages\simba\SimBA.py", line 593, in train_multiple_models_from_meta
    model_trainer.run()
  File "C:\Users\username\miniconda3\envs\simbadev\lib\site-packages\simba\train_mutiple_models.py", line 150, in run
    self.meta_dicts = self.__check_validity_of_meta_files(meta_file_paths=self.meta_file_lst)
  File "C:\Users\username\miniconda3\envs\simbadev\lib\site-packages\simba\train_mutiple_models.py", line 125, in __check_validity_of_meta_files
    meta_dict[ReadConfig.CLASS_WEIGHTS.value] = meta_dict['custom_weights']
KeyError: 'custom_weights'
```

To Reproduce: Save a model settings meta file with custom class weights, using the default weight values. The meta file contains:

Click on the green "Train multiple models (one for each saved setting)" button; the annotations are loaded, and then the error above is thrown.

Expected behavior: Model training proceeds without error.


Additional context: This seems to happen because `__check_validity_of_meta_files()` expects the meta file to have a column named `custom_weights`, whereas the actual name is `class_custom_weights`. Manually changing this in `train_mutiple_models.py` fixes that error, but the routine then fails again with:

```
Exception in Tkinter callback
Traceback (most recent call last):
  File "C:\Users\username\miniconda3\envs\simbadev\lib\tkinter\__init__.py", line 1705, in __call__
    return self.func(*args)
  File "C:\Users\username\miniconda3\envs\simbadev\lib\site-packages\simba\SimBA.py", line 363, in <lambda>
    button_train_multimodel = Button(label_trainmachinemodel, text='TRAIN MULTIPLE MODELS (ONE FOR EACH SAVED SETTING)',fg='green',command = lambda: threading.Thread(target=self.train_multiple_models_from_meta(config_path=self.config_path)).start())
  File "C:\Users\username\miniconda3\envs\simbadev\lib\site-packages\simba\SimBA.py", line 593, in train_multiple_models_from_meta
    model_trainer.run()
  File "C:\Users\username\miniconda3\envs\simbadev\lib\site-packages\simba\train_mutiple_models.py", line 148, in run
    self.meta_dicts = self.__check_validity_of_meta_files(meta_file_paths=self.meta_file_lst)
  File "C:\Users\username\miniconda3\envs\simbadev\lib\site-packages\simba\train_mutiple_models.py", line 124, in __check_validity_of_meta_files
    for k, v in meta_dict[ReadConfig.CLASS_WEIGHTS.value].items():
AttributeError: 'str' object has no attribute 'items'
```

This shows that the routine does not handle the weights dict being read from the file as a string. I was able to fix both problems using the same approach used elsewhere in SimBA:

```python
import ast
# The meta file stores the weights as a str, so parse it back into a dict, then cast each weight to int
if meta_dict[ReadConfig.CLASS_WEIGHTS.value] == 'custom':
    weights = ast.literal_eval(meta_dict['class_custom_weights'])
    meta_dict[ReadConfig.CLASS_WEIGHTS.value] = weights
    for k, v in meta_dict[ReadConfig.CLASS_WEIGHTS.value].items():
        meta_dict[ReadConfig.CLASS_WEIGHTS.value][k] = int(v)
```
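For anyone reproducing this outside SimBA, here is a minimal, self-contained sketch of both failures and the fix. It assumes the meta file is a one-row CSV whose cells come back as strings; the cell values are made up, and the `'class_weights'` column is a hypothetical stand-in for `ReadConfig.CLASS_WEIGHTS.value`, whose actual string is not shown here. Only the `class_custom_weights` and `custom_weights` names come from this thread.

```python
import ast
import csv
import io

# Hypothetical one-row meta file; 'class_weights' stands in for ReadConfig.CLASS_WEIGHTS.value
META_CSV = 'class_weights,class_custom_weights\ncustom,"{0: 1, 1: 1}"\n'

meta_dict = next(csv.DictReader(io.StringIO(META_CSV)))

# Failure 1: the reader looks up a column name the file does not contain
try:
    meta_dict['custom_weights']
except KeyError as err:
    print('KeyError:', err)

# Failure 2: the weights come back as a str, which has no .items()
raw = meta_dict['class_custom_weights']
print(type(raw).__name__)

# Fix: parse the string back into a dict, then cast each weight to int
if meta_dict['class_weights'] == 'custom':
    weights = ast.literal_eval(raw)
    meta_dict['class_weights'] = {k: int(v) for k, v in weights.items()}

print(meta_dict['class_weights'])
```

Every value read from a CSV cell is a string, which is why the dict has to be parsed explicitly before `.items()` can be called on it.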

Note that one could (or should) either update `ReadConfig.CUSTOM_WEIGHTS` to be `class_custom_weights`, or change the meta file back to use `custom_weights` as the column name, but I don't know SimBA well enough to say which approach would be the most reliable and safe.
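Whichever name is chosen, one way to keep the writer and the reader from drifting apart is to define each column name once and use it on both sides. This is a hypothetical sketch assuming a one-row CSV meta file: `MetaKey`, `save_meta` and `load_custom_weights` are illustrative names, not SimBA's API.

```python
import ast
import csv
import io
from enum import Enum


class MetaKey(Enum):
    # Hypothetical single source of truth for meta-file column names
    CLASS_WEIGHTS = 'class_weights'
    CLASS_CUSTOM_WEIGHTS = 'class_custom_weights'


def save_meta(weights: dict) -> str:
    """Write a one-row meta CSV using the shared column names."""
    buf = io.StringIO()
    writer = csv.DictWriter(buf, fieldnames=[k.value for k in MetaKey])
    writer.writeheader()
    writer.writerow({MetaKey.CLASS_WEIGHTS.value: 'custom',
                     MetaKey.CLASS_CUSTOM_WEIGHTS.value: str(weights)})
    return buf.getvalue()


def load_custom_weights(meta_csv: str) -> dict:
    """Read the row back with the same column names and parse the weights."""
    row = next(csv.DictReader(io.StringIO(meta_csv)))
    if row[MetaKey.CLASS_WEIGHTS.value] != 'custom':
        return {}
    weights = ast.literal_eval(row[MetaKey.CLASS_CUSTOM_WEIGHTS.value])
    return {k: int(v) for k, v in weights.items()}


print(load_custom_weights(save_meta({0: 1, 1: 3})))
```

Because both functions take the column name from the same enum, renaming it changes the writer and the reader together, so a `KeyError` like the one above cannot come from a mismatch between them.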

sronilsson commented 1 year ago

Thanks for this @florianduclot !

I can see it. First, as you say, there is a typo in the key: it should be `class_custom_weights`, not `custom_weights`. Second, the config meta file stores the weights as a str, so we have to convert them to a dict. Can you check whether this fixes it?

At the top of the file, import `literal_eval`: `from ast import literal_eval`

Change the routine to:

```python
if meta_dict[ReadConfig.CLASS_WEIGHTS.value] == 'custom':
    # The meta file stores the weights as a str: parse it into a dict, then cast each weight to int
    meta_dict[ReadConfig.CLASS_WEIGHTS.value] = literal_eval(meta_dict['class_custom_weights'])
    for k, v in meta_dict[ReadConfig.CLASS_WEIGHTS.value].items():
        meta_dict[ReadConfig.CLASS_WEIGHTS.value][k] = int(v)
```
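On the choice of `literal_eval` for this conversion: a dict written out with `str()` is a Python repr, which is not valid JSON when the keys are integers, and `literal_eval` only accepts literals, so it will not run arbitrary expressions found in a meta file the way `eval` would. A small check, assuming integer class keys (the real key type in SimBA's meta files is not shown in this issue):

```python
import json
from ast import literal_eval

stored = str({0: 1, 1: 2})  # a dict written to the meta file as text: '{0: 1, 1: 2}'

# json.loads rejects it: JSON object keys must be double-quoted strings
try:
    json.loads(stored)
    json_ok = True
except json.JSONDecodeError:
    json_ok = False

# literal_eval parses Python literals, including int-keyed dicts
parsed = literal_eval(stored)

# ...but refuses anything that is not a plain literal, such as a function call
try:
    literal_eval("__import__('os').getcwd()")
    call_rejected = False
except ValueError:
    call_rejected = True

print(json_ok, parsed, call_rejected)
```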

Let me know how it goes and I will update the pip package.

Simon

florianduclot commented 1 year ago

Thanks for the quick feedback, @sronilsson,

I can't test that right now, but it's equivalent to what I've already tested and found to work (see my last code snippet above), so I agree that it should work.

I'll try your exact solution when I get a chance, but it might not be today, unfortunately.

sronilsson commented 1 year ago

Hi @florianduclot, I pushed the updated code. Just a heads-up if you try it: I reorganized the files a fair bit to make them more readable, so this piece of code now lives in `simba.model.grid_search_rf`.

florianduclot commented 1 year ago

Just wanted to confirm: I tried the latest SimBA version yesterday and it indeed works as intended.

Thanks again for pushing the fix.