Closed: liuchengwang closed this issue 3 months ago
Hi! This is not something that can be done out of the box. Scikit-llm will handle the cases where the LLM produces invalid output (e.g. a label that does not exist), but the response object itself must be valid.
However, ignoring the error should be relatively straightforward: subclass the classifier and re-implement the predict method. It should look something like this:
```python
from skllm.utils import to_numpy as _to_numpy
from tqdm import tqdm

# Note: the import path for ZeroShotGPTClassifier depends on your
# scikit-llm version; in recent releases it is:
from skllm.models.gpt.classification.zero_shot import ZeroShotGPTClassifier


class CustomClassifier(ZeroShotGPTClassifier):
    def predict(self, X):
        X = _to_numpy(X)
        predictions = []
        for i in tqdm(range(len(X))):
            try:
                p = self._predict_single(X[i])
            except Exception:
                # Fall back to a placeholder label instead of aborting the batch
                p = "error"
            predictions.append(p)
        return predictions
```
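The error-swallowing pattern in that predict method can be sketched standalone, with a stub in place of the real LLM call (the stub and its 'context_length_exceeded' failure are purely illustrative, not scikit-llm internals):

```python
from typing import Callable, List


def predict_with_fallback(rows: List[str],
                          predict_single: Callable[[str], str]) -> List[str]:
    """Apply predict_single to each row, substituting 'error' on failure."""
    predictions = []
    for row in rows:
        try:
            p = predict_single(row)
        except Exception:
            # One bad row no longer aborts the whole batch
            p = "error"
        predictions.append(p)
    return predictions


# Stub standing in for the real LLM call: fails on overly long inputs,
# mimicking a 'context_length_exceeded' error.
def fake_predict(text: str) -> str:
    if len(text) > 10:
        raise RuntimeError("context_length_exceeded")
    return "positive"


print(predict_with_fallback(["short", "a very long document..."], fake_predict))
# -> ['positive', 'error']
```

The trade-off is that failed rows come back labeled "error" rather than raising, so you can filter or retry them afterwards instead of losing the whole run.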
Thanks, iryna-kondr. Your code example resolves the issue and is really helpful. Without it, it would not have been easy for me to understand how to wrap the LLM into a classifier step by step, as I am new to scikit-llm.
Hi, I am new to scikit-llm. Generally speaking, `ZeroShotGPTClassifier` works for my daily tasks as follows:

```python
clf = ZeroShotGPTClassifier(model=model_name)
clf.fit(X, y)
preds = clf.predict(X)
```

However, the input X sometimes contains a few rows that are too long. The job then breaks with a 'context_length_exceeded' error, so I cannot get `preds`. Sometimes I also fail to get predictions because a few rows trigger OpenAI's 'content_filter' error (OpenAI's neural multi-class classification models believe my input text contains harmful content, but their predictions are false positives). I think the error comes from the retry function: https://github.com/iryna-kondr/scikit-llm/blob/0bdea940fd369cdd5c5a0e625d3eea8f2b512208/skllm/utils.py#L92
The classifier still works even with an error, but the prediction will be random for that case.

Thank you in advance.