UKPLab / sentence-transformers

State-of-the-Art Text Embeddings
https://www.sbert.net
Apache License 2.0
15.21k stars · 2.47k forks

default_activation_function is not imported from string #2385

Open gsakkis opened 10 months ago

gsakkis commented 10 months ago

Based on this code, it seems that the intention is to support importing a default_activation_function from a string, but the first branch only sets self.config.sbert_ce_default_activation_function; the part that actually imports from a string is in the next elif statement, so it never gets executed.

tomaarsen commented 10 months ago

Hello!

As far as I can tell, this code does work as expected. However, if you provide default_activation_function then it must be a Callable rather than a string (as the docstring mentions). So, there are three cases:

  1. The user provides a Callable; this sets self.default_activation_function. It also updates the config so the model stores this info when saved.
  2. The user does not override the Callable, but the model configuration has sbert_ce_default_activation_function because someone in the past overrode default_activation_function and then saved the model. So, we load from the configuration and set self.default_activation_function.
  3. The user does not override the Callable and the model has no activation function defined. So, we load the default and set self.default_activation_function.

So, it's not possible to provide a string, it must be a Callable. I hope this makes some sense!
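The three cases above can be sketched as a small self-contained resolution function. This is a simplified paraphrase of the CrossEncoder logic, not the actual implementation; the placeholder Sigmoid and Identity classes stand in for torch.nn so the snippet runs without torch:

```python
import importlib


class Sigmoid:
    """Placeholder standing in for torch.nn.Sigmoid in this sketch."""


class Identity:
    """Placeholder standing in for torch.nn.Identity in this sketch."""


def resolve_activation(default_activation_function, config):
    # Case 1: the caller passed a Callable; record its dotted path on the
    # config so the choice survives a save/load round trip.
    if default_activation_function is not None:
        cls = type(default_activation_function)
        config.sbert_ce_default_activation_function = f"{cls.__module__}.{cls.__qualname__}"
        return default_activation_function

    # Case 2: a previously saved model recorded the activation as a string;
    # import the class back from that dotted path and instantiate it.
    path = getattr(config, "sbert_ce_default_activation_function", None)
    if path:
        module_path, _, cls_name = path.rpartition(".")
        return getattr(importlib.import_module(module_path), cls_name)()

    # Case 3: nothing was specified anywhere; fall back to the default
    # (Sigmoid for a single label, Identity otherwise).
    return Sigmoid() if config.num_labels == 1 else Identity()
```

Note that the string import in case 2 only ever runs when case 1 did not, which is why passing a string as default_activation_function never reaches it.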

gsakkis commented 10 months ago

I see, the "It also updates the config so the model stores this info when saved" was the part I was missing. So since it's not a bug, perhaps consider this a feature request? Happy to submit a PR in that case!

tomaarsen commented 10 months ago

I'm a little hesitant to accept that feature request right now, I think. The string that is now used to track the activation function is e.g. "torch.nn.modules.activation.Sigmoid", but that is very much dependent on the torch version. If an activation function is ever moved in the pytorch source code, e.g. to torch.nn.modules.activation.sigmoid.Sigmoid or something, then this stops working.
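For illustration, the stored string is derived roughly like this (a sketch of what a fullname-style helper does; sentence_transformers has such a helper in its util module, but treat the exact details here as an assumption):

```python
def fullname(obj):
    # Build "module.QualifiedName" from the object's class. The module part
    # is whatever file the class happens to live in, which is exactly why
    # the stored string can break if a class moves inside the torch source.
    cls = type(obj)
    return f"{cls.__module__}.{cls.__qualname__}"
```

Applied to a torch.nn.Sigmoid instance, this yields the private defining-module path "torch.nn.modules.activation.Sigmoid" rather than the stable public alias "torch.nn.Sigmoid".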

I think I might eventually refactor this such that users can provide Sigmoid and it will try to map that to the torch Sigmoid class internally in SentenceTransformers, because although I'm not sure that Sigmoid will always stay at "torch.nn.modules.activation.Sigmoid", I'm pretty confident that it will always be possible to do from torch.nn import Sigmoid. Does that make sense?
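A hedged sketch of that mapping (activation_from_name is a hypothetical helper name, not an existing API): try the short name on the stable torch.nn namespace first, then fall back to treating the string as a full dotted path.

```python
import importlib


def activation_from_name(name):
    # First try the stable public namespace: "Sigmoid" -> torch.nn.Sigmoid.
    try:
        return getattr(importlib.import_module("torch.nn"), name)()
    except (ImportError, AttributeError):
        pass
    # Fall back: treat the string as a full dotted path such as
    # "torch.nn.Sigmoid" and import the class from its module.
    module_path, _, cls_name = name.rpartition(".")
    return getattr(importlib.import_module(module_path), cls_name)()
```

With something like this, both "Sigmoid" and "torch.nn.Sigmoid" would resolve to the same class, without ever persisting the version-dependent private path.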

gsakkis commented 10 months ago

> I'm pretty confident that it will always be possible to do from torch.nn import Sigmoid

That's exactly how I specify it as a string in my YAML config:

  kwargs:
    model_name: cross-encoder/ms-marco-MiniLM-L-4-v2
    max_length: null
    tokenizer_args: {}
    automodel_args: {}
    default_activation_function: torch.nn.Sigmoid

tomaarsen commented 10 months ago

Oh, I see. That looks reasonable indeed. I was under the impression that only the "full path" worked.