Currently, the `GerbilizerEnsemble` class loads its weights in the constructor, whereas every other architecture type loads them separately through the built-in `load_state_dict` method.
To simplify the code, and to avoid checking whether a model is an ensemble every time we load weights, I recommend adding a single method, `load_weights`, to the `GerbilizerArchitecture` class. It would take two keyword arguments, `weights_path` and `use_final_weights`. The base implementation would raise an error if `weights_path` isn't provided, while the `GerbilizerEnsemble` subclass would override the method with ensemble-specific behavior. Then, wherever we need to load model weights, we can simply call:
```python
weights_path = '/some/path/to/weights.pt'  # or None if the model is an ensemble
use_final = True  # or False
model.load_weights(weights_path=weights_path, use_final_weights=use_final)
```
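A minimal sketch of that design, written in plain Python so it runs without PyTorch. The real classes are presumably `torch.nn.Module` subclasses; here the hypothetical helper `_load_file` stands in for the actual `load_state_dict` call, and the ensemble's per-member weight files (`member_paths`, also hypothetical) are an assumption about where each member's weights come from. `use_final_weights` is only passed through, since its exact meaning isn't specified here.

```python
from typing import List, Optional


class GerbilizerArchitecture:
    """Hypothetical stand-in for the real base class (presumably a torch.nn.Module)."""

    def __init__(self) -> None:
        self.loaded_from: Optional[str] = None  # records which file was "loaded"

    def load_weights(
        self, weights_path: Optional[str] = None, use_final_weights: bool = False
    ) -> None:
        # use_final_weights is accepted so every subclass shares one signature;
        # how it is interpreted is left to the real implementation.
        # Ordinary architectures need an explicit path, so fail loudly without one.
        if weights_path is None:
            raise ValueError(f"{type(self).__name__} requires weights_path")
        self._load_file(weights_path)

    def _load_file(self, path: str) -> None:
        # Placeholder for the real call, roughly:
        #     self.load_state_dict(torch.load(path))
        self.loaded_from = path


class GerbilizerEnsemble(GerbilizerArchitecture):
    """Ensemble that loads each member's weights instead of one shared file."""

    def __init__(
        self, members: List[GerbilizerArchitecture], member_paths: List[str]
    ) -> None:
        super().__init__()
        self.members = members
        self.member_paths = member_paths  # hypothetical: one weights file per member

    def load_weights(
        self, weights_path: Optional[str] = None, use_final_weights: bool = False
    ) -> None:
        # weights_path is ignored: each member already knows where its weights live.
        for member, path in zip(self.members, self.member_paths):
            member.load_weights(weights_path=path, use_final_weights=use_final_weights)


# Usage: both kinds of model are loaded through the same method.
single = GerbilizerArchitecture()
single.load_weights(weights_path="/some/path/to/weights.pt", use_final_weights=True)

ensemble = GerbilizerEnsemble(
    members=[GerbilizerArchitecture(), GerbilizerArchitecture()],
    member_paths=["member0.pt", "member1.pt"],
)
ensemble.load_weights(weights_path=None, use_final_weights=True)
```

The calling code never needs an `isinstance` check: both objects expose the same `load_weights` signature, and the ensemble simply ignores `weights_path`.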
Better yet, rather than raising an error as soon as `weights_path` is missing, the base implementation could check the config for a `WEIGHTS_PATH` entry and raise an error only if that can't be found either.
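A minimal sketch of that fallback, written as a hypothetical helper `resolve_weights_path`. It assumes the config is a dict-like mapping with `WEIGHTS_PATH` as a top-level key; neither detail is confirmed by the current code.

```python
from typing import Any, Mapping, Optional


def resolve_weights_path(
    weights_path: Optional[str], config: Mapping[str, Any]
) -> str:
    """Return the path to load: the explicit argument first, then config['WEIGHTS_PATH']."""
    if weights_path is not None:
        return weights_path
    config_path = config.get("WEIGHTS_PATH")
    if config_path is None:
        # Only now is there genuinely nothing to load from.
        raise ValueError(
            "No weights_path given and no 'WEIGHTS_PATH' entry in the config"
        )
    return config_path
```

The base `load_weights` would then call `resolve_weights_path(weights_path, self.config)` instead of raising directly, assuming the model stores its config as `self.config`. An explicit argument still takes priority, so callers can override the configured path when needed.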