Lightning-Universe / lightning-flash

Your PyTorch AI Factory - Flash enables you to easily configure and run complex AI recipes for over 15 tasks across 7 data domains
https://lightning-flash.readthedocs.io
Apache License 2.0

Data import for inference with Speech Recognition Task only accepts filepath #867

Closed marischu closed 2 years ago

marischu commented 3 years ago

🐛 Bug

The SpeechRecognition task's predict() cannot operate on raw tensor input.

To Reproduce

Steps to reproduce the behavior:

  1. Load a wav2vec2 pytorch model
  2. Load in audio data for inference and put in a tensor.
  3. Run model.predict([input_tensor]) or model.predict({"input_values": input_tensor})
  4. See error
    return filename.lower().endswith(extensions)
    AttributeError: 'Tensor' object has no attribute 'lower'

    or for the dictionary input:

    "..\lib\site-packages\flash\core\data\auto_dataset.py" line 98, in __getitem__
    return self._call_load_sample(self.data[index])
    KeyError: 0
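
The KeyError for the dictionary input is consistent with the `auto_dataset.__getitem__` line shown above, which indexes `self.data[index]` directly: integer-indexing a dict is a key lookup, not positional access. A minimal illustration in plain Python (no Flash required; the dict is a stand-in for the real input):

```python
# Stand-in for {"input_values": tensor}; integer indexing a dict
# looks for the key 0, mirroring self.data[index] with index == 0.
data = {"input_values": [0.1, 0.2, 0.3]}

try:
    sample = data[0]
except KeyError as err:
    print(f"KeyError: {err}")  # prints "KeyError: 0"
```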

Code sample

import torchaudio

from flash.audio import SpeechRecognition

data = torchaudio.datasets.LIBRISPEECH(".\\data",url="test-clean", download=True)

model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")

# both of the following will fail
res = model.predict({"input_values": data[0][0]})
res = model.predict(data[0][0])
print(res)

Expected behavior

The expectation was that model.predict() would accept a tensor or an input dictionary, since that qualifies as the "raw data" that predict is documented to accept:

https://github.com/PyTorchLightning/lightning-flash/blob/4b24b445ce9a65b21ab3fc0361d95d879f01cedf/flash/core/model.py#L466

Environment

Additional context

It is possible to bypass this error by:

1) Sanitizing the input by subclassing the SpeechRecognition task:

from typing import Any, Optional

import torch
from flash.audio import SpeechRecognition
from flash.core.data.data_source import DefaultDataKeys


class W2v2_model(SpeechRecognition):
    def _apply_batch_transfer_handler(
        self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: Optional[int] = None
    ) -> Any:
        # Wrap bare tensors in the expected key so the transfer handler accepts them.
        if isinstance(batch, torch.Tensor):
            return super()._apply_batch_transfer_handler(
                batch={DefaultDataKeys.INPUT: batch}, device=device, dataloader_idx=dataloader_idx
            )[DefaultDataKeys.INPUT]
        else:
            return super()._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx)

and 2) declaring a default data pipeline to pass to the predict stage:

pipeline = model.build_data_pipeline("default")
model.predict({"input_values":data[0][0]},data_pipeline=pipeline)

This works, but it seems to me that this is a bug, for the following reasons:

1) Passing a data pipeline to the predict function shouldn't change the interface of the function itself.
2) The docs suggest that either a filepath or raw data should work with predict (this may be a misreading on my part).
3) The way predict() works (creating a pipeline on every call) is too slow for my purposes (benchmarking), which is why I couldn't use the default behavior of creating a CSV so I can load a datamodule in the first place.

tchaton commented 3 years ago

You can do model.predict(data, data_source="tensor"). Mind trying?

tchaton commented 3 years ago

I would advise you to use trainer.predict and provide a DataLoader to address the speed concerns.
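
For reference, a sketch of that approach. The dataset wrapper and the padding collate below are hypothetical helpers written for illustration, not Flash APIs; the `trainer.predict` call at the end is left commented because it assumes the Flash 0.5-era pipeline accepts raw batched waveforms:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class WaveformDataset(Dataset):
    """Wraps a list of 1-D waveform tensors for batched inference."""

    def __init__(self, waveforms):
        self.waveforms = waveforms

    def __len__(self):
        return len(self.waveforms)

    def __getitem__(self, index):
        return self.waveforms[index]


def pad_collate(batch):
    """Zero-pad variable-length waveforms to the longest in the batch."""
    longest = max(w.shape[-1] for w in batch)
    return torch.stack(
        [torch.nn.functional.pad(w, (0, longest - w.shape[-1])) for w in batch]
    )


waveforms = [torch.randn(166960), torch.randn(120000)]
loader = DataLoader(WaveformDataset(waveforms), batch_size=2, collate_fn=pad_collate)

# With Flash (untested sketch):
# from flash import Trainer
# trainer = Trainer()
# predictions = trainer.predict(model, dataloaders=loader)
```

Building the DataLoader once avoids re-creating a pipeline per call, which is the benchmarking concern raised above.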

marischu commented 3 years ago

If I use:

res = model.predict(data[0][0],data_source="tensor")

then I still have the error:

IndexError: too many indices for tensor of dimension 1

Calling shape on the data sample yields:

torch.Size([1, 166960])

If I remove the first dimension, the error is:

IndexError: too many indices for tensor of dimension 0
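
For context on the two IndexErrors, this is how PyTorch reports indexing a tensor with more indices than it has dimensions; the exact indexing site inside Flash is an assumption, but the error text matches:

```python
import torch

sample = torch.randn(1, 166960)   # the shape reported above
squeezed = sample.squeeze(0)      # drop the leading dimension -> 1-D tensor

try:
    _ = squeezed[0, 0]            # two indices into a 1-D tensor
except IndexError as err:
    print(err)                    # "too many indices for tensor of dimension 1"
```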

I understand that trainer.predict with a DataLoader will work, and I have found other workarounds as well, but my concern is that if that is the only way to use this system, the documentation should reflect it.

stale[bot] commented 2 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.