JonasGeiping / cramming

Cramming the training of a (BERT-type) language model into limited compute.
MIT License

Loading checkpoints for use as a Hugging Face model #9

Closed: itay-nakash closed this issue 1 year ago

itay-nakash commented 1 year ago

Hello!

I'm trying to use a model that was pre-trained with cramming as a Hugging Face model (via AutoModel.from_pretrained(PATH_TO_MODEL)). The transformers library expects a pytorch_model.bin file, but save_final_model() currently writes a model.pth file instead.

Is there a recommended way to convert the files, or to use the checkpoints directly as a Hugging Face model? Thanks!

JonasGeiping commented 1 year ago

Hi, thanks for bringing this up. Checkpoints can now be uploaded to and downloaded from the Hugging Face Hub. To upload, set impl.push_to_huggingface_hub=True and give the repository address in impl.hf_directoy_name.
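Both settings are ordinary config overrides. As a minimal sketch, assuming pretraining is launched through the repository's Hydra entry point `pretrain.py` (the run name and repository ID below are placeholders, and the other overrides a real run needs, such as the architecture, training and data configs, are omitted), the two `impl.*` keys from this thread are just extra `key=value` arguments:

```python
import shlex


def hub_upload_args(run_name: str, repo_id: str) -> list[str]:
    """Build a pretraining command line that also pushes the result to the Hub.

    `run_name` and `repo_id` are placeholders. `impl.hf_directoy_name` is spelled
    exactly as the config key is given in this thread.
    """
    return [
        "python",
        "pretrain.py",
        f"name={run_name}",
        "impl.push_to_huggingface_hub=True",
        f"impl.hf_directoy_name={repo_id}",
    ]


cmd = hub_upload_args("my-crammed-run", "your-username/crammed-bert")
print(shlex.join(cmd))
```

The printed string can be pasted into a shell, or the list can be passed as-is to `subprocess.run`.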

Feel free to re-open this issue if any questions remain!

itay-nakash commented 1 year ago

Hi Jonas, first of all, thanks for adding Hugging Face support! That's super helpful :)

Using the load_local_model.py script, I tried loading both the model you uploaded and a model I trained myself, and both gave me errors when I tried to use them. First, the tokenizer in "JonasGeiping/crammed-bert" does not seem to be compatible with the model: its output includes token_type_ids, and I get the error TypeError: ScriptableLMForPreTraining.forward() got an unexpected keyword argument 'token_type_ids'.

But after removing this argument, I get a dtype error: the tensors the model produces do not have the types that flash-attention expects:

File "/data/home/itay.nakash/miniconda3/envs/nlp/lib/python3.10/site-packages/flash_attn/flash_attention.py", line 36, in forward
    assert qkv.dtype in [torch.float16, torch.bfloat16]
AssertionError

I get the same error with models I saved myself. I think it is related to the dtype Hugging Face loads the weights in by default (float32), compared to the more efficient half-precision types that flash-attention expects, but I'm not sure.
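One caller-side workaround for the token_type_ids error is to drop any key that `forward()` does not declare before unpacking the tokenizer output. A minimal sketch, using a hypothetical toy model and made-up token ids in place of the real tokenizer and the cramming model:

```python
import inspect


def filter_kwargs_for(fn, kwargs):
    """Return only the entries of `kwargs` that `fn` can accept as keywords.

    If `fn` takes **kwargs, everything is passed through unchanged.
    """
    params = inspect.signature(fn).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    return {
        k: v
        for k, v in kwargs.items()
        if k in params and params[k].kind is not inspect.Parameter.POSITIONAL_ONLY
    }


class ToyModel:
    """Hypothetical stand-in whose forward() has no token_type_ids parameter."""

    def forward(self, input_ids, attention_mask=None):
        return {"input_ids": input_ids, "attention_mask": attention_mask}


# Shaped like tokenizer output, with made-up ids and plain lists instead of tensors.
encoded_input = {
    "input_ids": [[101, 2023, 102]],
    "token_type_ids": [[0, 0, 0]],
    "attention_mask": [[1, 1, 1]],
}

model = ToyModel()
safe_input = filter_kwargs_for(model.forward, encoded_input)
output = model.forward(**safe_input)  # no TypeError: token_type_ids was dropped
print(sorted(safe_input))  # ['attention_mask', 'input_ids']
```

With a real Hugging Face tokenizer, passing `return_token_type_ids=False` when encoding avoids producing the key in the first place.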

JonasGeiping commented 1 year ago

Sorry, the c5 checkpoint can only be loaded with CUDA and AMP (automatic mixed precision); the documentation was not explicit enough about that. This can be done like so:

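import torch

# Assumes `model` and `encoded_input` (the tokenizer output) already exist.
# flash-attention only accepts fp16/bf16 inputs, so run on the GPU under autocast.
model.cuda()
cuda_input = {k: i.cuda() for k, i in encoded_input.items()}

with torch.autocast("cuda"):
    output = model(**cuda_input)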
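The autocast context in the snippet above is what satisfies flash-attention's dtype check: the model's weights stay float32, but matmul-heavy ops inside the context produce half-precision outputs. A minimal CPU-only sketch, with a toy `Linear` layer standing in for an attention layer's fused QKV projection (an illustration only; it does not touch flash-attention or the cramming model):

```python
import torch

# Toy stand-in for a fused query/key/value projection (hypothetical, not the
# cramming model). Weights default to float32.
torch.manual_seed(0)
proj = torch.nn.Linear(16, 3 * 16)
x = torch.randn(2, 8, 16)  # (batch, sequence, hidden), float32

# Without autocast the output stays float32, which is exactly what a check like
# `qkv.dtype in [torch.float16, torch.bfloat16]` rejects.
qkv_plain = proj(x)
print(qkv_plain.dtype)  # torch.float32

# Under autocast, matmul-heavy ops such as Linear run in lower precision.
# CPU autocast uses bfloat16 here; torch.autocast("cuda") defaults to float16.
with torch.autocast("cpu", dtype=torch.bfloat16):
    qkv_amp = proj(x)
print(qkv_amp.dtype)  # torch.bfloat16

# The weights themselves are untouched; only the op's output dtype changes.
print(proj.weight.dtype)  # torch.float32
```

flash-attention's kernels run on the GPU, so the real model still needs `model.cuda()` together with `torch.autocast("cuda")`; the CPU variant here only makes the dtype change visible.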

Also, token_type_ids should not be passed to the model; I should have removed them from the encoded text. I have modified the code so that the model now accepts the argument but does not use it.
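The change described here amounts to widening the `forward()` signature so the key is accepted and then ignored. A minimal sketch of that pattern with a hypothetical toy model (the real class is ScriptableLMForPreTraining, whose code is not shown in this thread):

```python
class ToyLM:
    """Hypothetical stand-in; not the actual ScriptableLMForPreTraining code."""

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        # token_type_ids is accepted so tokenizer output can be unpacked straight
        # into the call, but it is intentionally never read below.
        if attention_mask is None:
            attention_mask = [[1] * len(row) for row in input_ids]
        # Toy "output": number of non-padding tokens per sequence.
        return [sum(row) for row in attention_mask]

    def __call__(self, **inputs):
        return self.forward(**inputs)


encoded_input = {
    "input_ids": [[101, 7592, 102, 0]],  # made-up ids, last position is padding
    "token_type_ids": [[0, 0, 0, 0]],
    "attention_mask": [[1, 1, 1, 0]],
}

model = ToyLM()
print(model(**encoded_input))  # [3]
```

Changing the token_type_ids values does not change the result, which is the point: the argument is tolerated, not used.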

JonasGeiping commented 1 year ago

Hi, does this solve the issue?

JonasGeiping commented 1 year ago

Closing for now.

itay-nakash commented 1 year ago

Yes, this solves the issue (sorry for the late reply). Thanks!

JonasGeiping commented 1 year ago

The new checkpoint from the updated codebase can now be loaded without any of these issues :)