neurostatslab / vocalocator

Deep neural networks for sound source localization and vocalization attribution.
MIT License

Create unified way to load model weights #29

Closed by amanchoudhri 1 year ago

amanchoudhri commented 1 year ago

Currently, the GerbilizerEnsemble class loads its weights in the constructor, whereas all other architecture types load them through the separate built-in method load_state_dict.

To simplify the code and avoid checking whether a model is an ensemble every time we load weights, I recommend adding a single load_weights method to the GerbilizerArchitecture class that takes two kwargs, weights_path and use_final_weights. The base method will raise an error if weights_path isn't provided, and the GerbilizerEnsemble subclass can override it to provide the correct behavior for ensembles. Then, whenever we need to load model weights, we can just do:

weights_path = '/some/path/to/weights.pt' # or None if the model is an ensemble
use_final = True # or False!
model.load_weights(weights_path=weights_path, use_final_weights=use_final)

amanchoudhri commented 1 year ago

Better yet, rather than raising an error in the base class's implementation as soon as weights_path is missing, check the config for a 'WEIGHTS_PATH' key and only raise an error if that can't be found either.
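
For illustration, the proposal in this thread could be sketched roughly as below. This is a minimal, torch-free sketch under stated assumptions, not the repository's actual implementation: the classes are stand-ins for the real GerbilizerArchitecture and GerbilizerEnsemble, `_read_weights` and `loaded_from` are hypothetical names, the config is assumed to be a plain dict, and the ensemble is assumed to let each member resolve its own path from its own config. The meaning of use_final_weights is not specified here, so the sketch only passes it through.

```python
from typing import List, Optional


class GerbilizerArchitecture:
    """Stand-in for the real base class; only the loading logic is sketched."""

    def __init__(self, config: dict):
        self.config = config
        self.loaded_from = None  # hypothetical attribute: records what was loaded

    def _read_weights(self, path: str, use_final_weights: bool) -> None:
        # Hypothetical helper. In the real code this is presumably where
        # the checkpoint would be read and load_state_dict would be called.
        self.loaded_from = (path, use_final_weights)

    def load_weights(self, weights_path: Optional[str] = None,
                     use_final_weights: bool = False) -> None:
        # Prefer the explicit argument, then fall back to the config entry.
        path = weights_path or self.config.get("WEIGHTS_PATH")
        if path is None:
            raise ValueError(
                "No weights_path given and no 'WEIGHTS_PATH' found in config."
            )
        self._read_weights(path, use_final_weights)


class GerbilizerEnsemble(GerbilizerArchitecture):
    """Stand-in ensemble that overrides load_weights for its members."""

    def __init__(self, config: dict, models: List[GerbilizerArchitecture]):
        super().__init__(config)
        self.models = models

    def load_weights(self, weights_path: Optional[str] = None,
                     use_final_weights: bool = False) -> None:
        # Assumption: weights_path is ignored for ensembles; each member
        # looks up its own 'WEIGHTS_PATH' in its own config.
        for model in self.models:
            model.load_weights(use_final_weights=use_final_weights)
```

With this shape, callers can use the same `model.load_weights(weights_path=..., use_final_weights=...)` call for single models and ensembles, without checking the model type first.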