scikit-learn-contrib / hiclass

A Python library for hierarchical classification compatible with scikit-learn
BSD 3-Clause "New" or "Revised" License

BERT memory leak #125

Open lpfgarcia opened 3 months ago

lpfgarcia commented 3 months ago

Hello,

First, I would like to thank the developers for the experience of using hiclass. The library is very well developed, and the documentation is very comprehensive. I have two comments: one is a suggestion, and the other is a possible bug.

The suggestion is simple: could you include the FlatClassifier as a method? I saw that in some example notebooks. This would help users compare the different strategies.

The second comment relates to BERT. Unfortunately, bert_sklearn keeps the models in memory (GPU or CPU). This makes it impractical to use for hierarchies of almost any size. Could you consider saving the models during the HiClass fit stage and loading them during the predict stage?

Thank you!

mirand863 commented 3 months ago

Hi @lpfgarcia,

Thank you for your interest in HiClass.

> The suggestion is simple: could you include the FlatClassifier as a method? I saw that in some example notebooks. This would help users compare the different strategies.

I believe this would be out of scope for the library, since its purpose is to implement local hierarchical classifiers. Besides, wouldn't this be just an import of scikit-learn models?

> The second comment relates to BERT. Unfortunately, bert_sklearn keeps the models in memory (GPU or CPU). This makes it impractical to use for hierarchies of almost any size. Could you consider saving the models during the HiClass fit stage and loading them during the predict stage?

There is actually some code that stores trained models on disk and reloads them if training is restarted, but I did not implement it with the goal of saving memory. The relevant code is in two places:

https://github.com/scikit-learn-contrib/hiclass/blob/6f3799083b31e7ecdf504cd1c5fe3164874a9467/hiclass/HierarchicalClassifier.py#L367

https://github.com/scikit-learn-contrib/hiclass/blob/6f3799083b31e7ecdf504cd1c5fe3164874a9467/hiclass/LocalClassifierPerParentNode.py#L217

Would you be able to modify it to save memory with BERT and other models?
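
For readers following along, the general pattern of saving a trained model and reloading it when training restarts can be sketched as below. This is a minimal standard-library illustration, not the HiClass code at the links above; `fit_or_load`, `train_majority`, and the one-pickle-per-node file layout are hypothetical names and choices made for the example.

```python
import pickle
import tempfile
from collections import Counter
from pathlib import Path

calls = []  # records each real training run, to show when training is skipped


def train_majority(X, y):
    """Hypothetical stand-in for an expensive training step."""
    calls.append(len(X))
    return {"majority_label": Counter(y).most_common(1)[0][0]}


def fit_or_load(name, train, X, y, tmp_dir):
    """Train the model for one node, or reload it if a saved copy exists."""
    path = Path(tmp_dir) / f"{name}.pkl"
    if path.exists():
        # A previous run already trained this node, so skip training.
        with path.open("rb") as f:
            return pickle.load(f)
    model = train(X, y)
    with path.open("wb") as f:
        pickle.dump(model, f)
    return model


with tempfile.TemporaryDirectory() as tmp_dir:
    X, y = [[0], [1], [2]], ["a", "a", "b"]
    first = fit_or_load("root", train_majority, X, y, tmp_dir)
    # The second call finds root.pkl and loads it instead of training again.
    second = fit_or_load("root", train_majority, X, y, tmp_dir)
```

A more careful version would write to a temporary file and rename it, so that an interrupted run cannot leave a half-written pickle behind.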

lpfgarcia commented 2 months ago

Thanks @mirand863

If you're interested, I implemented the flat approach this way:

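from sklearn.base import BaseEstimator

# Separator used to join the levels of a label path into one flat label.
SEPARATOR = "::HiClass::Separator::"

class FlatClassifier(BaseEstimator):

    def __init__(self, local_classifier):
        self.local_classifier = local_classifier

    def fit(self, X, y):
        # Collapse each label path, e.g. ["Animal", "Dog"], into a single string.
        y = [SEPARATOR.join(i) for i in y]
        self.local_classifier.fit(X, y)
        return self

    def predict(self, X):
        # Split each predicted flat label back into its hierarchy levels.
        return [i.split(SEPARATOR) for i in self.local_classifier.predict(X)]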
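
To make the label transformation concrete, here is a small stand-alone sketch of the join/split round trip that the class above relies on. The label paths are made-up examples, and `flatten` and `unflatten` are hypothetical helper names, not HiClass functions. The round trip assumes that every level is a string and that the separator never occurs inside a label.

```python
SEPARATOR = "::HiClass::Separator::"


def flatten(y):
    """Join each label path into one flat string label."""
    return [SEPARATOR.join(path) for path in y]


def unflatten(flat_labels):
    """Split flat labels back into their hierarchy levels."""
    return [label.split(SEPARATOR) for label in flat_labels]


# Made-up two-level hierarchy for illustration.
y = [["Animal", "Dog"], ["Animal", "Cat"], ["Plant", "Tree"]]
flat = flatten(y)          # what the wrapped classifier is trained on
restored = unflatten(flat) # what predict hands back to the caller
```

Because each distinct path becomes its own class, the wrapped classifier sees an ordinary multi-class problem and never has to know about the hierarchy.
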

mirand863 commented 2 months ago

Hi @lpfgarcia,

I understand now what you mean. Sorry for the misunderstanding, and thank you for clarifying. I will add this to the code base with a comment acknowledging your contribution. If you would prefer to be credited as a contributor on GitHub, you can open a pull request and I will review it.

Best regards, Fabio

lpfgarcia commented 2 months ago

Hi @mirand863

Very good! Feel free to add the code to the library.

Kind regards, Luis

mirand863 commented 1 month ago

Hi @lpfgarcia,

A quick update on this. I just added the flat classifier https://github.com/scikit-learn-contrib/hiclass/pull/128

I will try to tackle the BERT memory leak in the next few days. My initial plan is to use the tmp_dir parameter, or something similar, to store the models directly on disk and free them from memory once fit succeeds. Please let me know if you have any thoughts or different ideas.
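
One way to picture this plan is the sketch below, which keeps at most one node's model in memory at a time. It is only an illustration under stated assumptions, not the eventual HiClass implementation: `fit_nodes`, `predict_node`, the toy `train` and `predict` functions, and the node names are all hypothetical.

```python
import pickle
import tempfile
from collections import Counter
from pathlib import Path


def fit_nodes(training_sets, train, tmp_dir):
    """Train one model per node, writing each to disk before the next starts."""
    paths = {}
    for name, (X, y) in training_sets.items():
        model = train(X, y)
        path = Path(tmp_dir) / f"{name}.pkl"
        with path.open("wb") as f:
            pickle.dump(model, f)
        paths[name] = path
        del model  # drop the only reference before training the next node
    return paths


def predict_node(path, predict, X):
    """Load one node's model just long enough to predict with it."""
    with Path(path).open("rb") as f:
        model = pickle.load(f)
    return predict(model, X)


def train(X, y):
    """Toy stand-in for a real training step: remember the majority label."""
    return {"label": Counter(y).most_common(1)[0][0]}


def predict(model, X):
    """Toy stand-in for a real prediction step."""
    return [model["label"] for _ in X]


training_sets = {
    "root": ([[0], [1], [2]], ["Animal", "Animal", "Plant"]),
    "Animal": ([[0], [1]], ["Dog", "Dog"]),
}

with tempfile.TemporaryDirectory() as tmp_dir:
    paths = fit_nodes(training_sets, train, tmp_dir)
    root_pred = predict_node(paths["root"], predict, [[7]])
    child_pred = predict_node(paths["Animal"], predict, [[7]])
```

The trade-off is extra disk I/O at predict time in exchange for a memory footprint that no longer grows with the number of nodes.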

Best regards, Fabio

lpfgarcia commented 1 month ago

Hi @mirand863,

Happy to hear about the addition of FlatClassifier.

Regarding the use of BERT with hiclass, I forked bert-sklearn and added a workaround in the code. Basically, once fit finishes, I move the model from the GPU to the CPU. This solved my problem.

Here is the link with the change: https://github.com/lpfgarcia/bert-sklearn/blob/b7cb9abcb123bdda743b2abc1ba70d7681276420/bert_sklearn/sklearn.py#L375
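
The general idea of moving a PyTorch model off the GPU once training ends can be sketched as follows. This is a generic illustration rather than a copy of the linked change: `release_gpu` is a hypothetical helper, and it assumes the model exposes the `to` method of `torch.nn.Module`.

```python
def release_gpu(model):
    """Move a fitted model to the CPU and release cached GPU memory."""
    # torch.nn.Module.to moves the parameters and returns the module itself.
    model = model.to("cpu")
    try:
        import torch
    except ImportError:
        return model  # PyTorch is not installed, so there is no cache to release
    if torch.cuda.is_available():
        # Hand cached, now-unused memory blocks back to the GPU driver.
        torch.cuda.empty_cache()
    return model
```

Calling such a helper at the end of each local classifier's fit means only the model currently being trained occupies GPU memory, although all models still accumulate in CPU memory.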

Kind regards, Luis