lightonai / pylate

Late Interaction Models Training & Retrieval
https://lightonai.github.io/pylate/
MIT License

Loading logic rework #52

Closed · NohTow closed this 2 months ago

NohTow commented 2 months ago

This PR solves some of the issues related to model loading.

First, it removes the `add_pooling_layer` entry from `model_kwargs`, which prevented the instantiation of a pooler within BERT models but is not supported across all encoders (see #51). I considered removing the pooler after initialization instead, but then the saved PyLate model has no weights for it, which triggers a warning saying those weights were not properly loaded from the checkpoint. Although this does not matter, since we never use the pooler, the message can be misleading. I therefore chose to leave the pooler in place: because we use the sequence embeddings rather than the pooled output, it only costs a small amount of useless computation, and I did not find a better solution.
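For illustration, a minimal sketch of the resulting behavior (the checkpoint name is just an example): the pooler is instantiated as part of the backbone, and PyLate simply ignores its output.

```python
from pylate import models

# The BERT pooler is instantiated as usual (no add_pooling_layer
# override), but PyLate only consumes the per-token sequence
# embeddings, so the pooled output is never used.
model = models.ColBERT(model_name_or_path="lightonai/colbertv2.0")

# One embedding per token, as required for late interaction scoring.
query_embeddings = model.encode(["late interaction retrieval"], is_query=True)
```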

Second, it adds a function to support loading models created using the stanford-nlp library (sketched after the list below). This has two benefits:

  1. Every ColBERT model (with a base model loadable using ST) is now natively compatible with PyLate, without having to convert it manually. This should greatly increase the number of compatible models (#50).
  2. Besides avoiding the conversion, it also means we no longer have to add the PyLate files to an existing stanford-nlp repository, as we did for Colbert-small. Beyond not duplicating the weights, this solves the issue where the Transformer (ST) folder was not at the root but in a subfolder, which caused the model configuration to not be loaded properly and thus prevented loading the model in a specified dtype (#49).
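Concretely, loading such a checkpoint should now be a one-liner; a hedged sketch (the model name is illustrative, any stanford-nlp ColBERT checkpoint with an ST-loadable base model should behave the same way):

```python
from pylate import models

# Checkpoints produced by the stanford-nlp ColBERT codebase are now
# loaded natively: PyLate builds its modules from the original weights,
# with no manual conversion and no extra PyLate files added to the
# source repository.
model = models.ColBERT(model_name_or_path="colbert-ir/colbertv2.0")
```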

I also took the opportunity to add dtype casting to the Dense layer so it matches the Transformer module's dtype.
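The idea, as a rough sketch (module indices assume the standard ST layout, with the Transformer at index 0 and the Dense projection at index 1; this is not the exact PyLate code):

```python
# Continuing from the `model` loaded in the previous sketch.
# Illustrative only: cast the Dense projection to the dtype of the
# Transformer backbone so half-precision checkpoints stay consistent
# end to end.
transformer_dtype = next(model[0].auto_model.parameters()).dtype
model[1].to(dtype=transformer_dtype)
```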

NohTow commented 2 months ago

Added a docstring and pinned the sentence-transformers version to below 3.1, as 3.1 introduces breaking changes (I already have some fixes, but they need more tests and will be better handled in a dedicated PR). Also fixed an issue with repositories containing both PyLate and stanford weights, where two Dense layers were loaded.