iryna-kondr / scikit-llm

Seamlessly integrate LLMs into scikit-learn.
https://beastbyte.ai/
MIT License

Is it possible not to raise an error for a few edge cases? #87

Closed: liuchengwang closed this issue 3 months ago

liuchengwang commented 4 months ago

Hi, I am new to scikit-llm. Generally speaking, ZeroShotGPTClassifier works well for my daily work, used as follows:

clf = ZeroShotGPTClassifier(model=model_name)
clf.fit(X, y)
preds = clf.predict(X)

However, the input X sometimes contains a few rows that are too long. They break the whole job with a 'context_length_exceeded' error, so I get no predictions at all. Other times, a few rows trigger OpenAI's 'content_filter' error and I again get no predictions. (OpenAI's classification models decide that my input text contains harmful content, but these are false positives.)

I think the error comes from the retry function in skllm/utils.py: https://github.com/iryna-kondr/scikit-llm/blob/0bdea940fd369cdd5c5a0e625d3eea8f2b512208/skllm/utils.py#L92
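A rough sketch of why a retry wrapper cannot rescue these rows: errors such as 'context_length_exceeded' or 'content_filter' are deterministic for a given input, so every attempt fails the same way and the wrapper eventually re-raises. The decorator below is a hypothetical, simplified stand-in, not scikit-llm's actual retry implementation; max_retries, base_delay and too_long_prompt are illustrative names.

```python
import time
from functools import wraps


def retry(max_retries=3, base_delay=1.0):
    """Hypothetical retry decorator: re-run the call on any exception with
    exponential backoff, and raise once every attempt has failed."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            last_err = None
            for attempt in range(max_retries):
                try:
                    return func(*args, **kwargs)
                except Exception as err:
                    last_err = err
                    # Back off 1x, 2x, 4x ... base_delay, but not after the last attempt
                    if attempt < max_retries - 1:
                        time.sleep(base_delay * 2 ** attempt)
            # A deterministic failure ends up here after all attempts
            raise RuntimeError(
                f"Failed after {max_retries} attempts: "
                f"{type(last_err).__name__}: {last_err}"
            ) from last_err
        return wrapper
    return decorator


calls = {"n": 0}


@retry(max_retries=3, base_delay=0.0)
def too_long_prompt(text):
    # Simulates an API call that always rejects this particular input
    calls["n"] += 1
    raise ValueError("context_length_exceeded")
```

Calling too_long_prompt("...") makes three attempts and then raises RuntimeError; inside predict, that single exception is enough to abort the run for the whole dataset.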

  1. Is there a quick way to turn off this error in version 1.0.0?
  2. I am fine with a random prediction for any row where the OpenAI API returns an error. Is that doable? I may be misremembering, but I recall that an older version of the scikit-llm documentation said the classifier would still work when an error occurred, with the prediction for that row being random.

Thank you in advance.

iryna-kondr commented 3 months ago

Hi! This is not something that can be done out of the box. Scikit-llm handles cases where the LLM produces invalid output (e.g. a label that does not exist), but the response object itself must be valid.
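To illustrate the distinction: a reply that arrives but names an unknown label can still be mapped to some valid label, whereas a failed API call leaves no reply to map at all. The helper below is a hypothetical sketch of that kind of post-processing, not scikit-llm's internal code; normalize_label, classes and the uniform random fallback are assumptions. (As far as I know, the library's own fallback for invalid labels is configured through the classifier's default_label parameter; check the API reference for your version.)

```python
import random


def normalize_label(raw_label, classes, rng=None):
    """Hypothetical post-processing step: keep a valid label, otherwise
    fall back to a label drawn uniformly at random from the known classes."""
    rng = rng or random.Random()
    # Ignore surrounding whitespace/newlines in the model's reply
    label = raw_label.strip()
    if label in classes:
        return label
    # Invalid output, e.g. a label that never appeared in the training data
    return rng.choice(list(classes))
```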

However, ignoring the error should be relatively straightforward: subclass the classifier and reimplement its predict method. It should look something like this:

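import numpy as np
from tqdm import tqdm
from skllm.models.gpt.classification.zero_shot import ZeroShotGPTClassifier
from skllm.utils import to_numpy as _to_numpy

class CustomClassifier(ZeroShotGPTClassifier):
    def predict(self, X):
        # Accept lists, pandas Series/DataFrames, etc.
        X = _to_numpy(X)
        predictions = []
        for x in tqdm(X):
            try:
                # Query the LLM for one sample at a time
                p = self._predict_single(x)
            except Exception:
                # Any API failure for this row (e.g. context_length_exceeded
                # or content_filter) gets a placeholder label instead of
                # aborting the whole run
                p = "error"
            predictions.append(p)
        return np.array(predictions)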
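For the second question (a random label instead of a placeholder), the same try/except idea can draw a random class. The sketch below is hypothetical and library-agnostic: predict_with_fallback takes any per-sample prediction callable, so it can be tried without an API key.

```python
import random
from typing import Any, Callable, Iterable, List, Optional, Sequence


def predict_with_fallback(
    predict_one: Callable[[Any], Any],
    X: Iterable[Any],
    classes: Sequence[Any],
    seed: Optional[int] = None,
) -> List[Any]:
    """Hypothetical helper: predict each sample independently and replace the
    result for any sample whose call raises with a random label from classes."""
    rng = random.Random(seed)
    predictions = []
    for x in X:
        try:
            predictions.append(predict_one(x))
        except Exception:
            # Only this row is affected, e.g. context_length_exceeded or content_filter
            predictions.append(rng.choice(list(classes)))
    return predictions
```

With a fitted classifier this might be called as predict_with_fallback(clf._predict_single, _to_numpy(X), clf.classes_), assuming fit sets a classes_ attribute; _predict_single is a private method, so both may change between versions. A uniform draw is the simplest choice; weighting the draw by class frequency in y is an alternative.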
liuchengwang commented 3 months ago

Thanks, iryna-kondr. Your code example resolved this issue. Since I am new to scikit-llm, without it I would have had a hard time working out, step by step, how the classifier wraps the LLM.