koaning / embetter

just a bunch of useful embeddings
https://koaning.github.io/embetter/
MIT License
465 stars, 15 forks

Expose encode parameters in fit function #51

Closed ChenghaoMou closed 1 year ago

ChenghaoMou commented 1 year ago

https://github.com/koaning/embetter/blob/257c076daaaa438b7ce813aa155fb89ba5985451/embetter/text/_sbert.py#L77

Would it be possible to expose encode parameters in fit?

I can help create a PR if needed.

koaning commented 1 year ago

In the scikit-learn API it's more normal to pass those to the __init__ of the object. But I agree that progress bars and other parameters might be welcome.

To help me understand the feature request; which parameters are you most interested in and why?

ChenghaoMou commented 1 year ago

Maybe it is easier to use **kwargs in the fit function. In the pipeline's fit call, we could then pass parameters like encoder__batch_size=64.
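For context, scikit-learn's Pipeline.fit accepts keyword arguments of the form step__param and forwards each one to the fit method of the named step. The toy function below only imitates that routing so the idea is visible without installing anything. It is not scikit-learn's code, and route_fit_params and the step names are made up for illustration.

```python
def route_fit_params(step_names, fit_params):
    """Toy imitation of step__param routing (hypothetical helper, not sklearn code)."""
    routed = {name: {} for name in step_names}
    for key, value in fit_params.items():
        if "__" not in key:
            raise ValueError(f"Expected a key of the form step__param, got {key!r}")
        # Split on the first double underscore: "encoder__batch_size" -> ("encoder", "batch_size")
        step, param = key.split("__", 1)
        if step not in routed:
            raise ValueError(f"Unknown step {step!r}")
        routed[step][param] = value
    return routed


# The "encoder" step would receive batch_size=64 in its fit call.
routed = route_fit_params(["encoder", "classifier"], {"encoder__batch_size": 64})
print(routed)
```

Under this scheme the parameter only reaches the step if that step's fit method accepts it, which is why the request is phrased in terms of **kwargs on fit.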

koaning commented 1 year ago

That's meant for things like sample_weight, where the parameter also depends on X. The changes you're suggesting don't rely on X at all and therefore seem better passed to the __init__ of the class.

Also, all of the objects in this library implement fit() as a no-op. These embeddings are already trained, so no actual fitting/training takes place.
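A minimal sketch of the convention described here, assuming a pretrained model behind the scenes: settings that do not depend on X go in the constructor, fit() does nothing and returns self, and the work happens in transform(). ToyEncoder and its length-based "embedding" are hypothetical placeholders, not embetter's implementation.

```python
class ToyEncoder:
    """Hypothetical stand-in for a pretrained embedding component."""

    def __init__(self, batch_size=32, show_progress_bar=False):
        # Settings that do not depend on X are stored at construction time.
        self.batch_size = batch_size
        self.show_progress_bar = show_progress_bar

    def fit(self, X, y=None):
        # No-op: the underlying model is assumed to be trained already.
        return self

    def transform(self, X):
        # Placeholder embedding: one feature per text (its length).
        # A real component would hand self.batch_size to the model here.
        return [[float(len(text))] for text in X]


enc = ToyEncoder(batch_size=64)
embeddings = enc.fit(["a", "bb"]).transform(["a", "bb"])
print(embeddings)  # [[1.0], [2.0]]
```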

ChenghaoMou commented 1 year ago

The reason I suggest putting them in fit is that sentence transformers exposes those parameters in the encode function rather than in __init__. Parameters like batch size should also stay adjustable at inference time, so that resources can be used more effectively.
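To illustrate why batch size is treated as a call-time setting, here is a small sketch in which the batch size changes how many batches are processed but not the result. toy_encode is a hypothetical function with a fake length-based embedding; it does not call sentence transformers.

```python
def toy_encode(texts, batch_size=32):
    """Hypothetical encoder where batch_size is chosen per call."""
    vectors = []
    n_batches = 0
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        # Placeholder embedding: one feature per text (its length).
        vectors.extend([float(len(text))] for text in batch)
        n_batches += 1
    return vectors, n_batches


texts = ["a", "bb", "ccc", "dddd", "eeeee"]
small_vecs, small_calls = toy_encode(texts, batch_size=2)
large_vecs, large_calls = toy_encode(texts, batch_size=5)

# Same embeddings either way; only the number of batches differs.
print(small_vecs == large_vecs, small_calls, large_calls)  # True 3 1
```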

That being said, I am happy as long as they are exposed, whether in __init__, fit, or predict. I can change them after initialization by assigning values directly.

Thank you kindly for this library, I will leave the decision to you.

koaning commented 1 year ago

I'll add some params to the __init__, not sure when, but I agree it'd be useful.

Note that none of the components in this library implement predict. Under the sklearn API, these are all .transform() operations.