huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Unhandled case when use_weighted_layer_sum=True and return_dict=True in WhisperForAudioClassification #28002

Closed: ElsebaiyMohamed closed this issue 9 months ago

ElsebaiyMohamed commented 10 months ago

@sanchit-gandhi I am using WhisperForAudioClassification with use_weighted_layer_sum=True, but calling forward fails. The encoder returns either a tuple or, when return_dict=True, a dict-like output. However, the use_weighted_layer_sum=True branch assumes a tuple, so the line hidden_states = torch.stack(encoder_outputs, dim=1) raises an error whenever the encoder returns a dict. Setting return_dict=False works around this, but using the model later in a pipeline then fails, because the pipeline expects the model to return a dict rather than a tuple. Link to code with the problem:

        if self.config.use_weighted_layer_sum:
            hidden_states = torch.stack(encoder_outputs, dim=1)  # This line raises an error when return_dict=True and use_weighted_layer_sum=True
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = encoder_outputs[0]
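The failure comes down to the weighted-layer-sum branch assuming a single output type. Below is a minimal, dependency-free sketch of a branch that accepts either form. It is not the transformers implementation. The helper names are hypothetical, each layer's hidden state is modeled as a flat list of floats rather than a torch tensor, and it assumes output_hidden_states=True, so the tuple form is (last_hidden_state, hidden_states, ...) and the dict form exposes a "hidden_states" key.

```python
import math
from collections.abc import Mapping


def select_layer_hidden_states(encoder_outputs):
    """Return the per-layer hidden states from either encoder output form.

    Assumption: output_hidden_states=True, so the tuple form is
    (last_hidden_state, hidden_states, ...) and the dict-like form
    has a "hidden_states" key.
    """
    if isinstance(encoder_outputs, tuple):
        # return_dict=False: position 1 holds the per-layer hidden states
        return encoder_outputs[1]
    if isinstance(encoder_outputs, Mapping):
        # return_dict=True: look the field up by name instead of stacking the whole output
        return encoder_outputs["hidden_states"]
    raise TypeError(f"Unsupported encoder output type: {type(encoder_outputs)!r}")


def softmax(weights):
    """Numerically stable softmax over a list of floats."""
    m = max(weights)
    exps = [math.exp(w - m) for w in weights]
    total = sum(exps)
    return [e / total for e in exps]


def weighted_layer_sum(encoder_outputs, layer_weights):
    """Softmax-weighted sum over layers.

    Stand-in: each layer is a flat list of floats instead of a tensor.
    """
    layers = select_layer_hidden_states(encoder_outputs)
    if len(layers) != len(layer_weights):
        raise ValueError("need exactly one weight per layer")
    norm = softmax(layer_weights)
    width = len(layers[0])
    # Element-wise sum of each layer scaled by its normalized weight
    return [sum(w * layer[i] for w, layer in zip(norm, layers)) for i in range(width)]


# Usage: both output forms give the same result
layers = ([1.0, 2.0], [3.0, 4.0])
as_tuple = ([9.0, 9.0], layers)  # (last_hidden_state, hidden_states)
as_dict = {"last_hidden_state": [9.0, 9.0], "hidden_states": layers}
print(weighted_layer_sum(as_tuple, [0.0, 0.0]))  # [2.0, 3.0]
print(weighted_layer_sum(as_dict, [0.0, 0.0]))   # [2.0, 3.0]
```

Equal weights reduce to a plain average over layers. In transformers, ModelOutput subclasses OrderedDict and also accepts integer indices (skipping None fields), so indexing position 1 directly should select the hidden states in both forms. The explicit type check above just makes that assumption visible.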

Reproducing the error

import torch
from transformers import WhisperForAudioClassification, AutoFeatureExtractor
from datasets import load_dataset

dataset = load_dataset('seba3y/speechocean762')
dataset = dataset['train']
sampling_rate = dataset.features["audio"].sampling_rate
dataset = dataset.remove_columns(['utt_name', 'text', 'completeness', 'fluency', 'prosodic'])

feature_extractor = AutoFeatureExtractor.from_pretrained("seba3y/whisper-tiny")
model = WhisperForAudioClassification.from_pretrained("seba3y/whisper-tiny",
                                                      use_weighted_layer_sum=True,
                                                      return_dict=True)
# Test whether inference works; `dataset` is already the train split, so index it directly
inputs = feature_extractor(dataset[3]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # fails inside forward() at torch.stack(encoder_outputs, dim=1)
predicted_class_ids = torch.argmax(logits, dim=-1).item()
predicted_label = model.config.id2label[predicted_class_ids]
print(predicted_label)
amyeroberts commented 10 months ago

Hi @ElsebaiyMohamed, thanks for raising this issue and providing details on the error + a snippet. Could you also provide information about the running environment: run transformers-cli env in the terminal and copy-paste the output?

ElsebaiyMohamed commented 10 months ago

Hi @amyeroberts, apologies for the delayed response! 🙏 Life threw a curveball, but I'm back on track. Thanks for your patience!

Regarding your request, here's the output of transformers-cli env:

transformers version: 4.36.0
Platform: Linux-5.15.133+-x86_64-with-glibc2.35
Python version: 3.10.12
Huggingface_hub version: 0.19.4
Safetensors version: 0.4.1
Accelerate version: 0.25.0
Accelerate config:  not found
PyTorch version (GPU?): 2.0.0 (True)
Tensorflow version (GPU?): 2.13.0 (True)
Flax version (CPU?/GPU?/TPU?): 0.7.5 (gpu)
Jax version: 0.4.21
JaxLib version: 0.4.21
Using GPU in script?: yes
Using distributed or parallel set-up in script?: no

Let me know if there's anything else I can help you with.

amyeroberts commented 9 months ago

@ElsebaiyMohamed Great - thanks for providing this info!

cc @sanchit-gandhi @ylacombe