BUTSpeechFIT / VBx

Variational Bayes HMM over x-vectors diarization

How to load the ResNet model in pytorch #55

Closed: anuragkumar95 closed this issue 1 year ago

anuragkumar95 commented 1 year ago

While it is much easier to work with the ONNX backend, I would like to run the ResNet embedding model on the GPU using PyTorch. There seem to be multiple PyTorch checkpoints, so what is the correct path to the model checkpoint for loading in PyTorch? I see 5 .pth checkpoints and 1 .onnx checkpoint:

-a----         1/12/2023   1:45 PM       60669003 final.onnx
-a----         1/12/2023   1:45 PM       26214400 raw_81.pth.zip.partaa
-a----         1/12/2023   1:45 PM       26214400 raw_81.pth.zip.partab
-a----         1/12/2023   1:45 PM       26214400 raw_81.pth.zip.partac
-a----         1/12/2023   1:45 PM       26214400 raw_81.pth.zip.partad
-a----         1/12/2023   1:45 PM       21165076 raw_81.pth.zip.partae

fnlandini commented 1 year ago

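Hi Anurag, the .pth file is split only because it would otherwise be too large to upload to the repository: the five raw_81.pth.zip.part* files are pieces of a single zip archive, not separate checkpoints. If you look here, you will see that the parts are joined and unzipped like this: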

# Rebuild raw_81.pth once: join the parts in order, then extract the archive
if [ ! -f "$WEIGHTS_DIR/raw_81.pth" ]; then
    cat "$WEIGHTS_DIR"/raw_81.pth.zip.part* > "$WEIGHTS_DIR/unsplit_raw_81.pth.zip"
    unzip "$WEIGHTS_DIR/unsplit_raw_81.pth.zip" -d "$WEIGHTS_DIR/"
fi
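The directory listing above looks like Windows PowerShell output, where the cat/unzip pipeline may not work as written. A minimal Python sketch of the same steps follows, using only the standard library. The function name reassemble_and_extract is hypothetical, and it assumes, as the shell snippet implies, that the archive contains raw_81.pth at its top level.

```python
import glob
import os
import shutil
import zipfile


def reassemble_and_extract(weights_dir, stem="raw_81.pth"):
    """Join <stem>.zip.part* in order and unzip the result into weights_dir.

    Hypothetical helper mirroring the shell snippet: does nothing if
    <stem> already exists. Returns the path to the extracted file.
    """
    target = os.path.join(weights_dir, stem)
    if os.path.isfile(target):
        return target

    # Suffixes aa, ab, ac, ... sort alphabetically in the order they must be
    # joined; glob() itself returns files in arbitrary order, hence sorted().
    parts = sorted(glob.glob(os.path.join(weights_dir, stem + ".zip.part*")))
    if not parts:
        raise FileNotFoundError(f"no {stem}.zip.part* files in {weights_dir}")

    # Byte-for-byte concatenation, equivalent to `cat part* > joined`.
    joined = os.path.join(weights_dir, "unsplit_" + stem + ".zip")
    with open(joined, "wb") as out:
        for part in parts:
            with open(part, "rb") as src:
                shutil.copyfileobj(src, out)

    # Equivalent to `unzip joined -d weights_dir`.
    with zipfile.ZipFile(joined) as archive:
        archive.extractall(weights_dir)
    return target
```

Calling reassemble_and_extract(weights_dir) with the folder that holds the .part files should leave raw_81.pth next to them; that file is the single PyTorch checkpoint to load.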

I hope this helps.

anuragkumar95 commented 1 year ago

Thanks a lot! This helps!
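Once raw_81.pth has been rebuilt, running the embedding model on the GPU comes down to torch.load with a map_location plus load_state_dict on a matching model instance. The sketch below is an assumption-laden illustration, not the VBx loading code: it assumes the checkpoint is either a plain state_dict or a dict wrapping it under a "state_dict" key (a common PyTorch convention, not confirmed for this file in the thread), and it leaves constructing the ResNet to the caller because the class name and constructor arguments are not stated here. Both function names are hypothetical.

```python
from collections.abc import Mapping


def extract_state_dict(checkpoint):
    """Return the parameter mapping from a loaded checkpoint object.

    Assumption: the checkpoint is a plain state_dict or a dict that wraps
    it under a "state_dict" key.
    """
    if isinstance(checkpoint, Mapping) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    # Strip the "module." prefix that nn.DataParallel adds to keys, if present.
    return {
        (key[len("module."):] if key.startswith("module.") else key): value
        for key, value in checkpoint.items()
    }


def load_on_device(model, weights_path, device="cuda"):
    """Load weights_path into an already constructed model on `device`.

    `model` must be an instance of the ResNet class matching the checkpoint;
    its exact name and constructor arguments are not given in this thread.
    """
    import torch  # imported lazily so extract_state_dict works without torch

    # map_location places the stored tensors directly on the target device.
    checkpoint = torch.load(weights_path, map_location=device)
    model.load_state_dict(extract_state_dict(checkpoint))
    # nn.Module.to() and .eval() both return the module itself.
    return model.to(device).eval()
```

If load_state_dict reports missing or unexpected keys, comparing them with model.state_dict().keys() usually shows whether the wrong class or configuration was constructed. Recent PyTorch releases default torch.load to weights_only=True; if the file stores non-tensor objects and is refused, passing weights_only=False is an option for a file you trust.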