Jwoo5 / fairseq-signals

A collection of deep learning models for ECG data processing based on the fairseq framework

Issue with build_model checkpoint loading. #34

Closed. KadenMc closed this issue 2 months ago.

KadenMc commented 2 months ago

I was testing out the fairseq_signals.models.build_model function with a pretrained model (Wav2Vec2CMSCModel) and realized that it's always loading the model with random weights, even with from_checkpoint=True.

This seems to be the case regardless of what information is in the config, e.g., setting checkpoint.restore_file, so I think this function simply doesn't support checkpoint loading as expected, at least for this model configuration.
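
For reference, this is roughly what I was running (the config/task handling and the checkpoint_utils usage below are a sketch of my setup, not a verified minimal repro):

from fairseq_signals import tasks
from fairseq_signals.models import build_model
from fairseq_signals.utils import checkpoint_utils

# Pull the training config saved inside the checkpoint.
state = checkpoint_utils.load_checkpoint_to_cpu("/path/to/checkpoint.pt")
cfg = state["cfg"]

# Build the task and model from that config.
task = tasks.setup_task(cfg.task)
model = build_model(cfg.model, task, from_checkpoint=True)
# Expected pretrained weights here, but the parameters come out freshly initialized.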

Jwoo5 commented 2 months ago

I'm looking into that functionality and expect it to be resolved soon.

Jwoo5 commented 2 months ago

Okay, this is actually not a bug or unintended behavior. To be clear, from_checkpoint=True in build_model does not mean that the model weights are loaded from the checkpoint. Instead, it is used to remove missing config entries before loading the model checkpoint. That said, I think adding the ability to load model weights when from_checkpoint=True is set in build_model would be a good idea.
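
To make that concrete, here is a rough sketch of the intended two-step pattern (the module paths and state-dict keys are assumptions following the usual fairseq layout):

from fairseq_signals import tasks
from fairseq_signals.models import build_model
from fairseq_signals.utils import checkpoint_utils

state = checkpoint_utils.load_checkpoint_to_cpu("/path/to/checkpoint.pt")

# from_checkpoint=True only sanitizes the saved config (e.g., drops entries
# that no longer exist in the current model dataclass); the weights stay random.
task = tasks.setup_task(state["cfg"].task)
model = build_model(state["cfg"].model, task, from_checkpoint=True)

# Restoring the weights is a separate, explicit step.
model.load_state_dict(state["model"], strict=True)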

Jwoo5 commented 2 months ago

So, if you want to restore the full model weights (not just a subset of pretrained modules) and continue training with the fairseq-signals framework, you can set checkpoint.restore_file in the config to the absolute path of the checkpoint to load. This will also restore the other trainer states, such as the optimizer, LR scheduler, and meters, and resume training.
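
For example, with a fairseq-style hydra launch (the entry point, config directory, and config name here are placeholders for whatever your setup uses):

fairseq-hydra-train \
    checkpoint.restore_file=/absolute/path/to/checkpoint.pt \
    --config-dir /path/to/config/pretraining \
    --config-name w2v_cmsc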

Or, if you need to load only the model weights from the checkpoint, it can simply be done like this:

from fairseq_signals.models import build_model_from_checkpoint

# Builds the model from the config stored in the checkpoint and loads its weights.
model = build_model_from_checkpoint(
    checkpoint_path="/path/to/checkpoint.pt"
)
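
From there, the model can be used directly, e.g. for a quick forward pass (the input shape and the source keyword below are assumptions based on the usual wav2vec 2.0-style interface, not a documented contract):

import torch

model.eval()

# Dummy batch: 1 recording, 12 leads, 5000 samples (shape is an assumption).
ecg = torch.randn(1, 12, 5000)

with torch.no_grad():
    outputs = model(source=ecg)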

Tagging @KadenMc to track this issue. If all of your concerns have been addressed, please feel free to close this issue!