huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Couldn't process request when using "automatic-speech-recognition" pipeline on SageMaker #19743

Closed: wildgeece96 closed this issue 1 year ago

wildgeece96 commented 1 year ago

System Info

transformers

torch

sagemaker

Who can help?

@Narsil @patrickvonplaten
@anton-l

Reproduction

Run the code below on SageMaker:

from sagemaker.huggingface import HuggingFaceModel
import sagemaker
import numpy as np

role = sagemaker.get_execution_role()
# Hub Model configuration. https://huggingface.co/models
hub = {
    # 'HF_MODEL_ID':'openai/whisper-base',
    'HF_MODEL_ID': 'facebook/wav2vec2-base-960h',
    'HF_TASK':'automatic-speech-recognition'
}

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    transformers_version='4.17.0',
    pytorch_version='1.10.2',
    py_version='py38',
    env=hub,
    role=role, 
)

# deploy model to SageMaker Inference
predictor = huggingface_model.deploy(
    initial_instance_count=1, # number of instances
    instance_type='ml.m5.xlarge' # ec2 instance type
)

input_array = np.random.randn(1, 10000)
predictor.predict({
    'inputs': input_array
})

The request returned an InternalServerError:


2022-10-19T09:42:32,618 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Prediction error
--
2022-10-19T09:42:32,619 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Traceback (most recent call last):
2022-10-19T09:42:32,619 [INFO ] W-9000-facebook__wav2vec2-base-9 com.amazonaws.ml.mms.wlm.WorkerThread - Backend response time: 9
2022-10-19T09:42:32,619 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File "/opt/conda/lib/python3.8/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py", line 234, in handle
2022-10-19T09:42:32,621 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     response = self.transform_fn(self.model, input_data, content_type, accept)
2022-10-19T09:42:32,621 [INFO ] W-9000-facebook__wav2vec2-base-9 ACCESS_LOG - /169.254.178.2:42092 "POST /invocations HTTP/1.1" 400 16
2022-10-19T09:42:32,622 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File "/opt/conda/lib/python3.8/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py", line 190, in transform_fn
2022-10-19T09:42:32,623 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     predictions = self.predict(processed_data, model)
2022-10-19T09:42:32,623 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File "/opt/conda/lib/python3.8/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py", line 158, in predict
2022-10-19T09:42:32,624 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     prediction = model(inputs)
2022-10-19T09:42:32,624 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File "/opt/conda/lib/python3.8/site-packages/transformers/pipelines/automatic_speech_recognition.py", line 168, in __call__
2022-10-19T09:42:32,624 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     return super().__call__(inputs, **kwargs)
2022-10-19T09:42:32,625 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File "/opt/conda/lib/python3.8/site-packages/transformers/pipelines/base.py", line 1016, in __call__
2022-10-19T09:42:32,625 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     outputs = [output for output in final_iterator]
2022-10-19T09:42:32,626 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File "/opt/conda/lib/python3.8/site-packages/transformers/pipelines/base.py", line 1016, in <listcomp>
2022-10-19T09:42:32,626 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     outputs = [output for output in final_iterator]
2022-10-19T09:42:32,626 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File "/opt/conda/lib/python3.8/site-packages/transformers/pipelines/pt_utils.py", line 111, in __next__
2022-10-19T09:42:32,627 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     item = next(self.iterator)
2022-10-19T09:42:32,627 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File "/opt/conda/lib/python3.8/site-packages/transformers/pipelines/pt_utils.py", line 253, in __next__
2022-10-19T09:42:32,627 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     processed = self.infer(next(self.iterator), **self.params)
2022-10-19T09:42:32,628 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
2022-10-19T09:42:32,628 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     data = self._next_data()
2022-10-19T09:42:32,631 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 561, in _next_data
2022-10-19T09:42:32,631 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
2022-10-19T09:42:32,631 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 32, in fetch
2022-10-19T09:42:32,631 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     data.append(next(self.dataset_iter))
2022-10-19T09:42:32,631 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File "/opt/conda/lib/python3.8/site-packages/transformers/pipelines/pt_utils.py", line 170, in __next__
2022-10-19T09:42:32,632 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     processed = next(self.subiterator)
2022-10-19T09:42:32,632 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File "/opt/conda/lib/python3.8/site-packages/transformers/pipelines/automatic_speech_recognition.py", line 222, in preprocess
2022-10-19T09:42:32,632 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
2022-10-19T09:42:32,632 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - ValueError: We expect a numpy ndarray as input, got `<class 'list'>`
2022-10-19T09:42:32,633 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -
2022-10-19T09:42:32,633 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - During handling of the above exception, another exception occurred:
2022-10-19T09:42:32,633 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -
2022-10-19T09:42:32,633 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Traceback (most recent call last):
2022-10-19T09:42:32,633 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File "/opt/conda/lib/python3.8/site-packages/mms/service.py", line 108, in predict
2022-10-19T09:42:32,633 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     ret = self._entry_point(input_batch, self.context)
2022-10-19T09:42:32,633 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File "/opt/conda/lib/python3.8/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py", line 243, in handle
2022-10-19T09:42:32,633 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     raise PredictionException(str(e), 400)
2022-10-19T09:42:32,634 [INFO ] W-facebook__wav2vec2-base-9-3-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - mms.service.PredictionException: We expect a numpy ndarray as input, got `<class 'list'>` : 400

Expected behavior

When using Transformers on SageMaker, I noticed that the automatic speech recognition pipeline cannot handle requests sent to it once it is deployed on SageMaker.

The SageMaker Hugging Face Inference Toolkit runs inference through pipelines.

AutomaticSpeechRecognitionPipeline does not accept a list as the inputs argument of its __call__ method, so a request body sent through the API would need to look like this:

{
    "inputs": np.array([1., 2., 3., 4.])
}

However, I cannot pass an ndarray through the JSON serializer; I can only pass a list.

To solve this problem, the pipeline should accept a list as input and return something like:

{'text': 'UROUND ME ON YOU E'}

Narsil commented 1 year ago

Hi @wildgeece96 .

The np.array is supposed to be the raw audio waveform at the correct sampling rate, right?

If so, the bug seems to come from somewhere in the SageMaker stack, where the numpy array gets converted to a list.

I am tentatively against adding support for lists instead of numpy arrays:

That being said, there is probably a workaround: would sending a WAV file work for you?

https://stackoverflow.com/questions/51300281/how-to-convert-numpy-array-to-bytes-object-without-save-audio-file-on-disk

I couldn't quickly find better code with my Google-fu, but it should be doable to build a WAV-like buffer in memory with minimal reallocations. Does SageMaker allow sending raw bytes? Would that approach work?
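
A minimal sketch of the in-memory WAV idea, using only numpy and the standard-library `wave` module. The helper name `waveform_to_wav_bytes` is hypothetical (it is not part of transformers or the SageMaker SDK), and it assumes a mono float waveform in [-1.0, 1.0]; the 16 kHz default matches the rate facebook/wav2vec2-base-960h was trained on.

```python
import io
import wave

import numpy as np


def waveform_to_wav_bytes(waveform, sampling_rate: int = 16_000) -> bytes:
    """Hypothetical helper: encode a mono float waveform as 16-bit PCM WAV bytes, in memory."""
    # Drop singleton dimensions, e.g. shape (1, 10000) -> (10000,).
    samples = np.asarray(waveform, dtype=np.float32).squeeze()
    if samples.ndim != 1:
        raise ValueError("expected a single-channel (mono) waveform")
    # Clip to the valid range, then scale to native-endian signed 16-bit integers
    # (the wave module expects native byte order and swaps when needed).
    pcm = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(1)  # mono
        wav_file.setsampwidth(2)  # 2 bytes = 16 bits per sample
        wav_file.setframerate(sampling_rate)
        wav_file.writeframes(pcm.tobytes())
    # wave does not close a file object it did not open, so the buffer is still readable.
    return buffer.getvalue()


wav_bytes = waveform_to_wav_bytes(np.random.uniform(-1.0, 1.0, (1, 10000)))
```

The resulting bytes would then have to reach the endpoint unchanged, for example through a serializer that passes raw bytes through, such as `sagemaker.serializers.DataSerializer(content_type="audio/x-audio")`. That serializer choice is an assumption to verify against the SageMaker notebook linked later in this thread.
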

wildgeece96 commented 1 year ago

I confirmed that custom inference code like the following works:

from transformers import pipeline
from transformers.pipelines import AutomaticSpeechRecognitionPipeline
import numpy as np

def model_fn(model_dir) -> AutomaticSpeechRecognitionPipeline:
    # model_dir is ignored; the model is pulled from the Hugging Face Hub instead
    return pipeline(model="facebook/wav2vec2-base-960h")

def predict_fn(data, asr_pipeline):
    inputs = data.pop("inputs", data)
    parameters = data.pop("parameters", None)
    # JSON requests deliver the waveform as a list, so convert it back to an ndarray
    if isinstance(inputs, list):
        inputs = np.array(inputs, dtype=np.float32)
    # pass inputs together with any kwargs given under "parameters"
    if parameters is not None:
        prediction = asr_pipeline(inputs, **parameters)
    else:
        prediction = asr_pipeline(inputs)
    return prediction

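
The list-to-ndarray conversion in a handler like this can be checked locally without downloading a model. The sketch below repeats the `predict_fn` logic and swaps the real pipeline for a hypothetical stand-in, `fake_asr_pipeline`, which mimics only the ndarray check from the traceback; it is not part of transformers.

```python
import numpy as np


def predict_fn(data, asr_pipeline):
    # Same conversion as the custom handler: JSON delivers lists,
    # so rebuild an ndarray before calling the pipeline.
    inputs = data.pop("inputs", data)
    parameters = data.pop("parameters", None)
    if isinstance(inputs, list):
        inputs = np.array(inputs, dtype=np.float32)
    if parameters is not None:
        return asr_pipeline(inputs, **parameters)
    return asr_pipeline(inputs)


def fake_asr_pipeline(inputs, **kwargs):
    # Hypothetical stand-in for the real pipeline: it applies the same
    # "numpy ndarray only" check seen in the traceback and reports what it received.
    if not isinstance(inputs, np.ndarray):
        raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
    return {"dtype": str(inputs.dtype), "shape": list(inputs.shape), "kwargs": kwargs}


# A request body as the endpoint sees it after JSON decoding: a plain list.
request_body = {"inputs": np.random.randn(10000).tolist()}
print(predict_fn(request_body, fake_asr_pipeline))
# {'dtype': 'float32', 'shape': [10000], 'kwargs': {}}
```
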
wildgeece96 commented 1 year ago

Thanks, @Narsil.

In my use case, I deployed a wav2vec model on SageMaker. When I send a request through the SageMaker SDK, a SageMaker serializer (such as JSONSerializer or NumPySerializer) serializes the input before it is sent to the endpoint. To use the SageMaker Hugging Face Inference Toolkit I have to use JSON serialization, and JSONSerializer cannot pass an ndarray as-is; it converts it to a list.

After reading your comment, I think the conversion logic should be implemented in the SageMaker Hugging Face Inference Toolkit, because it is specific to the SageMaker use case.
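
To illustrate why the endpoint ends up with a list, here is a small sketch using only numpy and the standard-library `json` module. It does not call the SageMaker SDK; it only assumes, as described above, that a JSON serializer has to turn an ndarray into nested Python lists before encoding it.

```python
import json

import numpy as np

waveform = np.random.randn(1, 10000)

# The standard-library encoder has no rule for ndarrays and raises TypeError.
try:
    json.dumps({"inputs": waveform})
except TypeError as err:
    print("json.dumps failed:", err)

# A JSON serializer therefore has to fall back to nested Python lists...
body = json.dumps({"inputs": waveform.tolist()})

# ...so the endpoint decodes a list, which is exactly what the pipeline rejects.
decoded = json.loads(body)["inputs"]
print(type(decoded), len(decoded), len(decoded[0]))  # <class 'list'> 1 10000
```
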

philschmid commented 1 year ago

Hello @wildgeece96, the automatic-speech-recognition pipeline is supported. Instead of sending numpy data, you need to send the audio itself. Check out this example: https://github.com/huggingface/notebooks/blob/main/sagemaker/20_automatic_speech_recognition_inference/sagemaker-notebook.ipynb

github-actions[bot] commented 1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.