opensearch-project / ml-commons

ml-commons provides a set of common machine learning algorithms, e.g. k-means, or linear regression, to help developers build ML related features within OpenSearch.
Apache License 2.0

[BUG] Failure when deploying finetuned BERT models #2179

Closed IanMenendez closed 8 months ago

IanMenendez commented 8 months ago

What is the bug? OpenSearch fails when deploying fine-tuned BERT models.

How can one reproduce the bug? Steps to reproduce the behavior:

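  1. Fine-tune a BERT model:
from datetime import datetime

from sentence_transformers import InputExample, SentenceTransformer, losses, models
from torch.utils.data import DataLoader

def model_finetuning(model_name):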
    train_examples = []
    sample = InputExample(texts=["testing", "opensearch"], label = float(0.92))
    train_examples.append(sample)

    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=64)

    word_embedding_model = models.Transformer('sentence-transformers/' + model_name)
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

    loss_func = losses.CosineSimilarityLoss(model)

    model.fit(train_objectives=[(train_dataloader, loss_func)],
            epochs=2,
            save_best_model=True,
            output_path=f"./{model_name}" + str(datetime.utcnow()))

model_finetuning("all-MiniLM-L6-v2")
  2. Convert to TorchScript:
import torch
import transformers
from transformers import BertTokenizer

model = transformers.AutoModelForSequenceClassification.from_pretrained("./path_to_finetuned_model", return_dict=False)
encoder = BertTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2", return_dict=False)

text = "dummy text for tracing"
dummy_text_tensor = encoder.encode([text], return_tensors='pt')

model.eval()
traced_model = torch.jit.trace(model, dummy_text_tensor)

torch.jit.save(traced_model, "traced_bert.pt")
  3. Zip traced_bert.pt and tokenizer.json (taken from the fine-tuned model)
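
The packaging step can also be scripted. This is a minimal sketch using only the Python standard library; package_model is a hypothetical helper name, and it assumes both files should sit at the root of the archive.

```python
import zipfile
from pathlib import Path


def package_model(model_file, tokenizer_file, zip_path):
    """Write both files into the root of a new ZIP archive and return its path."""
    zip_path = Path(zip_path)
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        for source in (Path(model_file), Path(tokenizer_file)):
            # arcname drops any directory prefix so the file lands at the archive root.
            archive.write(source, arcname=source.name)
    return zip_path
```

Calling package_model("traced_bert.pt", "tokenizer.json", "traced_bert.zip") from the directory that holds both files would produce the archive used in the next step.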

  4. Upload the model and deploy it:

from opensearch_py_ml.ml_commons import MLCommonClient
from opensearchpy import OpenSearch

CLUSTER_URL = 'http://localhost:9200'

def get_os_client(cluster_url = CLUSTER_URL,
                  username='admin',
                  password='admin'):
    client = OpenSearch(
        hosts=[cluster_url],
        http_auth=(username, password),
        verify_certs=False
    )
    return client

ml_client = MLCommonClient(get_os_client())

model_path = 'traced_bert.zip'
model_config_path = 'config.json'
ml_client.register_model( model_path, model_config_path, isVerbose=True, deploy_model=True)
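
The contents of config.json are not shown in this report. As a rough sketch only, a config for a TorchScript text embedding model might be generated as below. The field names are an assumption based on the ml-commons model registration format and should be checked against the documentation for your version; the model name is hypothetical, and 384 is the embedding size of all-MiniLM-L6-v2.

```python
import json


def build_model_config(name, embedding_dimension, version="1.0.0"):
    """Return a registration config dict (field names are assumed, not taken from this issue)."""
    return {
        "name": name,
        "version": version,
        "model_format": "TORCH_SCRIPT",
        "model_config": {
            "model_type": "bert",
            "embedding_dimension": embedding_dimension,
            "framework_type": "sentence_transformers",
        },
    }


# Hypothetical model name; 384 matches all-MiniLM-L6-v2.
config = build_model_config("all-MiniLM-L6-v2-finetuned", 384)
config_text = json.dumps(config, indent=2)
```

Writing config_text to config.json gives a file that can be passed as model_config_path.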

The error message is:

[2024-03-06T02:02:15,331][ERROR][o.o.m.e.a.DLModel        ] [2fb78d2ede13] Failed to deploy model VfJ_EY4BBpCbFD_3IEAT
ai.djl.translate.TranslateException: ai.djl.engine.EngineException: forward() Expected a value of type 'Tensor' for argument 'input_ids' but instead found type 'Dict[str, Tensor]'.
Position: 1
Declaration: forward(__torch__.transformers.models.bert.modeling_bert.BertForSequenceClassification self, Tensor input_ids) -> ((Tensor))
    at ai.djl.inference.Predictor.batchPredict(Predictor.java:189) ~[api-0.21.0.jar:?]
    at ai.djl.inference.Predictor.predict(Predictor.java:126) ~[api-0.21.0.jar:?]
    at org.opensearch.ml.engine.algorithms.TextEmbeddingModel.warmUp(TextEmbeddingModel.java:51) ~[opensearch-ml-algorithms-2.12.0.0.jar:?]
    at org.opensearch.ml.engine.algorithms.DLModel.doLoadModel(DLModel.java:219) ~[opensearch-ml-algorithms-2.12.0.0.jar:?]
    at org.opensearch.ml.engine.algorithms.DLModel.lambda$loadModel$1(DLModel.java:280) [opensearch-ml-algorithms-2.12.0.0.jar:?]
    at java.base/java.security.AccessController.doPrivileged(AccessController.java:571) [?:?]
    at org.opensearch.ml.engine.algorithms.DLModel.loadModel(DLModel.java:247) [opensearch-ml-algorithms-2.12.0.0.jar:?]
    at org.opensearch.ml.engine.algorithms.DLModel.initModel(DLModel.java:139) [opensearch-ml-algorithms-2.12.0.0.jar:?]
    at org.opensearch.ml.engine.MLEngine.deploy(MLEngine.java:125) [opensearch-ml-algorithms-2.12.0.0.jar:?]
    at org.opensearch.ml.model.MLModelManager.lambda$deployModel$51(MLModelManager.java:1020) [opensearch-ml-2.12.0.0.jar:2.12.0.0]
    at org.opensearch.core.action.ActionListener$1.onResponse(ActionListener.java:82) [opensearch-core-2.12.0.jar:2.12.0]
    at org.opensearch.ml.model.MLModelManager.lambda$retrieveModelChunks$72(MLModelManager.java:1553) [opensearch-ml-2.12.0.0.jar:2.12.0.0]
    at org.opensearch.core.action.ActionListener$1.onResponse(ActionListener.java:82) [opensearch-core-2.12.0.jar:2.12.0]
    at org.opensearch.action.support.ThreadedActionListener$1.doRun(ThreadedActionListener.java:78) [opensearch-2.12.0.jar:2.12.0]
    at org.opensearch.common.util.concurrent.ThreadContext$ContextPreservingAbstractRunnable.doRun(ThreadContext.java:913) [opensearch-2.12.0.jar:2.12.0]
    at org.opensearch.common.util.concurrent.AbstractRunnable.run(AbstractRunnable.java:52) [opensearch-2.12.0.jar:2.12.0]
    at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144) [?:?]
    at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642) [?:?]
    at java.base/java.lang.Thread.run(Thread.java:1583) [?:?]
Caused by: ai.djl.engine.EngineException: forward() Expected a value of type 'Tensor' for argument 'input_ids' but instead found type 'Dict[str, Tensor]'.
Position: 1
Declaration: forward(__torch__.transformers.models.bert.modeling_bert.BertForSequenceClassification self, Tensor input_ids) -> ((Tensor))
    at ai.djl.pytorch.jni.PyTorchLibrary.moduleRunMethod(Native Method) ~[pytorch-engine-0.21.0.jar:?]
    at ai.djl.pytorch.jni.IValueUtils.forward(IValueUtils.java:53) ~[pytorch-engine-0.21.0.jar:?]
    at ai.djl.pytorch.engine.PtSymbolBlock.forwardInternal(PtSymbolBlock.java:145) ~[pytorch-engine-0.21.0.jar:?]
    at ai.djl.nn.AbstractBaseBlock.forward(AbstractBaseBlock.java:79) ~[api-0.21.0.jar:?]
    at ai.djl.nn.Block.forward(Block.java:127) ~[api-0.21.0.jar:?]
    at ai.djl.inference.Predictor.predictInternal(Predictor.java:140) ~[api-0.21.0.jar:?]
    at ai.djl.inference.Predictor.batchPredict(Predictor.java:180) ~[api-0.21.0.jar:?]
    ... 18 more

What is the expected behavior? The model should deploy successfully.

What is your host/environment?

IanMenendez commented 8 months ago

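Fixed this by creating a simple custom model whose forward() accepts the Dict[str, Tensor] input that the deployment passes in (see the error above) and returns its output as a dict with a sentence_embedding entry.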

from transformers import BertModel

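class CustomBertModel(BertModel):
    def forward(self, input_dict, *args, **kwargs):
        # The deployed model is called with a dict of tensors, so unpack the ids here.
        input_ids = input_dict['input_ids']

        result = super().forward(input_ids, *args, **kwargs)

        # Index 1 of BertModel's output is the pooled output.
        sentence_embedding = result[1]
        # Return the embedding under the 'sentence_embedding' key.
        return {'sentence_embedding': sentence_embedding}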
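
For anyone following along, here is a minimal sketch of how a wrapper like this could be traced with a dict input. It is not the exact script used here: to stay self-contained it builds a tiny, randomly initialized BertConfig (the sizes are made up) instead of loading the fine-tuned checkpoint, and it only passes input_ids.

```python
import torch
from transformers import BertConfig, BertModel


class CustomBertModel(BertModel):
    def forward(self, input_dict):
        # Unpack the dict input and run the regular BERT forward pass.
        result = super().forward(input_dict["input_ids"])
        # Index 1 of the output is the pooled output.
        return {"sentence_embedding": result[1]}


# Hypothetical tiny config so the sketch runs without downloading weights.
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=64,
)
model = CustomBertModel(config).eval()

example = {"input_ids": torch.tensor([[1, 5, 7, 2]])}

with torch.no_grad():
    # strict=False is needed because the traced forward returns a dict.
    traced = torch.jit.trace(model, (example,), strict=False)
    output = traced(example)
```

With a real checkpoint, CustomBertModel.from_pretrained on the fine-tuned model directory would replace the random initialization, and torch.jit.save(traced, "traced_bert.pt") would write the file to zip.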