linto-ai / whisper-timestamped

Multilingual Automatic Speech Recognition with word-level timestamps and confidence
GNU Affero General Public License v3.0

Loading finetuned model serialized with safetensors (and/or sharded models) #160

Closed: BlahBlah314 closed this issue 8 months ago

BlahBlah314 commented 8 months ago

I've got a fine-tuned Whisper model on Hugging Face and I would like to try timestamping and diarization on it. But when I load it, it looks for a pytorch_model.bin file, and my model has no .bin file, only safetensors. How can I resolve this issue?

Jeronymous commented 8 months ago

Can you give some code to reproduce?

LaurinmyReha commented 8 months ago

from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("your_model_dir")
# Re-save as PyTorch .bin weights instead of safetensors; a large max_shard_size avoids splitting them into several shards
model.save_pretrained("your_model_dir", safe_serialization=False, max_shard_size="10GB")

Jeronymous commented 8 months ago

What is "your_model_dir" in your case? You mentioned a "finetuned whisper model on hugging face", so I thought you were using a model hosted on Hugging Face.

Also, there is no whisper_timestamped in your code. Could you share code that reproduces how you load the model in whisper_timestamped? (I guess that is the code that throws an error about a missing "pytorch_model.bin" file.)

Jeronymous commented 8 months ago

I think I found a way to reproduce:

import shutil
from transformers import WhisperForConditionalGeneration
import whisper_timestamped as whisper

audio_file = "XXX.wav"  # use an audio file here

# Re-save a Hub model locally as sharded PyTorch .bin files (at most 100MB per shard)
shutil.rmtree("tmp_model", ignore_errors=True)
model = WhisperForConditionalGeneration.from_pretrained("qanastek/whisper-tiny-french-cased")
model.save_pretrained("tmp_model", safe_serialization=False, max_shard_size="100MB")

# Reference: load the same model directly from the Hugging Face Hub
model = whisper.load_model("qanastek/whisper-tiny-french-cased")
expected = whisper.transcribe(model, audio_file)

# Load the sharded local copy: this is the call that fails
model = whisper.load_model("tmp_model")
output = whisper.transcribe(model, audio_file)

assert expected == output

The second load of the model (whisper.load_model("tmp_model")) fails. This happens because the model is sharded. I will investigate.

@LaurinmyReha Can you just check that your folder "your_model_dir" contains something similar to this:

config.json  generation_config.json  pytorch_model-00001-of-00003.bin  pytorch_model-00002-of-00003.bin  pytorch_model-00003-of-00003.bin  pytorch_model.bin.index.json
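For reference, the folder layouts discussed here can be told apart from the file names alone. A minimal sketch, assuming the standard transformers weight file names (model.safetensors, model.safetensors.index.json, pytorch_model.bin, pytorch_model.bin.index.json). detect_checkpoint_layout is a hypothetical helper, not part of whisper_timestamped or transformers, and preferring safetensors when both formats exist is an assumption modeled on how transformers usually picks weights:

```python
import os

# Standard weight file names used by Hugging Face transformers checkpoints.
SAFE_SINGLE = "model.safetensors"
SAFE_INDEX = "model.safetensors.index.json"
BIN_SINGLE = "pytorch_model.bin"
BIN_INDEX = "pytorch_model.bin.index.json"


def detect_checkpoint_layout(filenames):
    """Return (format, sharded) for the file names of a model folder.

    format is "safetensors" or "pytorch"; sharded is True when only an
    index JSON (pointing to several shard files) is present.
    Safetensors is checked first (assumed preference).
    """
    names = set(filenames)
    if SAFE_SINGLE in names:
        return ("safetensors", False)
    if SAFE_INDEX in names:
        return ("safetensors", True)
    if BIN_SINGLE in names:
        return ("pytorch", False)
    if BIN_INDEX in names:
        return ("pytorch", True)
    raise FileNotFoundError("no known weight file found")


def detect_folder_layout(folder):
    """Same check, applied to the files present in a local folder."""
    return detect_checkpoint_layout(os.listdir(folder))


if __name__ == "__main__":
    # The folder listed above: sharded PyTorch weights plus an index file.
    files = [
        "config.json", "generation_config.json",
        "pytorch_model-00001-of-00003.bin", "pytorch_model-00002-of-00003.bin",
        "pytorch_model-00003-of-00003.bin", "pytorch_model.bin.index.json",
    ]
    print(detect_checkpoint_layout(files))  # ('pytorch', True)
```

A loader that only handles the ("pytorch", False) case would fail on both the sharded folder above and a safetensors-only repo.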
LaurinmyReha commented 8 months ago

Exactly. Increase max_shard_size to something larger and you should be fine, I think. 10GB, as in the example, will definitely be enough :)

model.save_pretrained("your_model_dir", safe_serialization=False, max_shard_size="10GB")
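To see why raising max_shard_size helps: save_pretrained splits the weights into shards no larger than the limit, so a large enough limit yields a single file. A simplified sketch of greedy sharding, assuming tensors are packed in order and that "MB"/"GB" are decimal units ("MiB"/"GiB" binary). parse_size and plan_shards are hypothetical names, and this is not transformers' exact code:

```python
import re

# Decimal and binary unit multipliers (e.g. "10GB" = 10 * 10**9 bytes).
_UNITS = {
    "KB": 10**3, "MB": 10**6, "GB": 10**9,
    "KIB": 2**10, "MIB": 2**20, "GIB": 2**30,
}


def parse_size(size):
    """Convert a size such as '100MB' or '10GB' (or a plain int) to bytes."""
    if isinstance(size, int):
        return size
    match = re.fullmatch(r"\s*(\d+)\s*([KMG]I?B)\s*", size.upper())
    if match is None:
        raise ValueError(f"unrecognised size: {size!r}")
    return int(match.group(1)) * _UNITS[match.group(2)]


def plan_shards(tensor_sizes, max_shard_size):
    """Greedily pack tensors, in order, into shards of at most max_shard_size bytes.

    tensor_sizes is a list of (name, n_bytes). A tensor larger than the
    limit still gets a shard of its own.
    """
    limit = parse_size(max_shard_size)
    shards, current, current_bytes = [], [], 0
    for name, n_bytes in tensor_sizes:
        # Start a new shard if this tensor would overflow a non-empty one.
        if current and current_bytes + n_bytes > limit:
            shards.append(current)
            current, current_bytes = [], 0
        current.append(name)
        current_bytes += n_bytes
    if current:
        shards.append(current)
    return shards


if __name__ == "__main__":
    sizes = [("a", 60 * 10**6), ("b", 60 * 10**6), ("c", 30 * 10**6)]
    print(plan_shards(sizes, "100MB"))  # [['a'], ['b', 'c']]
    print(plan_shards(sizes, "10GB"))   # [['a', 'b', 'c']]
```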

BlahBlah314 commented 8 months ago

Can you give some code to reproduce?

Hi! Here is my code:

import whisper_timestamped as whisper

audio = whisper.load_audio("4d8b691a-7529-47d1-a3e3-00ce32f430c2.wav")

model = whisper.load_model("BlahBlah314/whisper_LargeV3FR_ft-V1", device="cuda")

result = whisper.transcribe(model, audio, language="fr")

I should mention that my model has no .bin weights, only safetensors. Do I need to convert my safetensors model somehow? Or is there a config option to set in whisper_timestamped?

LaurinmyReha commented 8 months ago

from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("BlahBlah314/whisper_LargeV3FR_ft-V1")
# safe_serialization=False writes PyTorch .bin weights instead of safetensors
model.save_pretrained("path_to_where_you_want_to_save_your_bin", safe_serialization=False, max_shard_size="10GB")

Exactly. I think this should do the conversion. There is probably a better way, but this seemed easiest to me. After that, add the .bin file to where your safetensors file resides (in your case, the Hugging Face Hub) and you should be able to load the model :)

BlahBlah314 commented 8 months ago

Thank you! I'll try that.

Jeronymous commented 8 months ago

Thanks, both! I figured out how to support sharded models and the safetensors format.
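Supporting a sharded checkpoint essentially means reading the index JSON and merging the shards into one state dict. A minimal sketch, assuming the transformers index layout (a "weight_map" from parameter name to shard file). load_sharded_state_dict is a hypothetical helper and not necessarily how whisper_timestamped implements it. The shard loader is passed in, so the same code could cover .bin shards (torch.load) and .safetensors shards (safetensors.torch.load_file):

```python
import json
import os


def load_sharded_state_dict(folder, index_name, load_shard):
    """Rebuild one state dict from a sharded checkpoint folder.

    index_name is e.g. "pytorch_model.bin.index.json"; its "weight_map"
    maps each parameter name to the shard file that stores it.
    load_shard(path) must return a dict {param_name: tensor} for one shard.
    """
    with open(os.path.join(folder, index_name), encoding="utf-8") as f:
        weight_map = json.load(f)["weight_map"]

    state_dict = {}
    # Load each shard file once, in a stable order, and merge its tensors.
    for shard_file in sorted(set(weight_map.values())):
        state_dict.update(load_shard(os.path.join(folder, shard_file)))

    # Every parameter listed in the index must have been found in a shard.
    missing = set(weight_map) - set(state_dict)
    if missing:
        raise KeyError(f"parameters missing from shards: {sorted(missing)}")
    return state_dict
```

For the "tmp_model" folder from the reproduction above, the call would look like load_sharded_state_dict("tmp_model", "pytorch_model.bin.index.json", torch.load), with the resulting dict then mapped onto the model.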

@BlahBlah314 You should now be able to use your "BlahBlah314/whisper_LargeV3FR_ft-V1" model with the new version (1.14.4).
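To check that the installed version includes the fix, a minimal sketch using importlib.metadata. It assumes the package is installed under the distribution name "whisper-timestamped"; parse_release and is_at_least are hypothetical helpers that only compare the numeric release part (packaging.version.Version would handle full PEP 440 versions such as pre-releases):

```python
from importlib.metadata import PackageNotFoundError, version


def parse_release(v):
    """Turn '1.14.4' into (1, 14, 4), stopping at the first non-numeric part."""
    parts = []
    for piece in v.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def is_at_least(installed, required="1.14.4"):
    """True if the installed release is >= the required one (tuple comparison)."""
    return parse_release(installed) >= parse_release(required)


def installed_whisper_timestamped_version():
    """Installed version string, or None (assumed distribution name)."""
    try:
        return version("whisper-timestamped")
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    current = installed_whisper_timestamped_version()
    if current is None or not is_at_least(current):
        print("whisper_timestamped is missing or older than 1.14.4")
```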

LaurinmyReha commented 8 months ago

Very cool! Thanks for the quick update!