marischu closed this issue 2 years ago.
You can do model.predict(data, data_source="tensor"). Mind trying?
I would advise you to use trainer.predict and provide a DataLoader for speed concerns.
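For reference, a minimal sketch of the DataLoader route using plain PyTorch (the waveform shape is taken from the sample reported below; the final trainer.predict call is Lightning/flash API and is shown only as a comment, since the model and trainer are not in scope here):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical raw waveforms, each shaped like the reported sample [1, 166960];
# stacking N of them gives a dataset tensor of shape [N, 166960].
waveforms = torch.randn(4, 166960)
loader = DataLoader(TensorDataset(waveforms), batch_size=2)

for (batch,) in loader:
    print(batch.shape)  # each batch is [2, 166960]

# With a flash model and Lightning trainer in scope, prediction would then be roughly:
# predictions = trainer.predict(model, dataloaders=loader)
```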
If I use:
res = model.predict(data[0][0], data_source="tensor")
then I still have the error:
IndexError: too many indices for tensor of dimension 1
Calling shape on the data sample yields:
torch.Size([1, 166960])
If I remove the first dimension, the error is:
IndexError: too many indices for tensor of dimension 0
I understand that the trainer.predict and DataLoader will work, and I have found other workarounds as well, but my concern is that if that is the only way to use this system, the documentation should reflect that.
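Both IndexErrors above can be reproduced with plain tensor indexing, independent of flash: they arise whenever code indexes a tensor with more indices than it has dimensions. A minimal sketch (this mimics the symptom, not flash's actual code path):

```python
import torch

sample = torch.zeros(1, 166960)  # shape of the reported data sample

try:
    # Drop the batch dim, then index the 1-D result with two indices.
    sample[0][0, 0]
except IndexError as e:
    print(e)  # "too many indices for tensor of dimension 1"

scalar = sample[0][0]  # a 0-dimensional tensor
try:
    # Any integer indexing of a 0-d tensor also raises IndexError,
    # matching the "dimension 0" variant reported above.
    scalar[0]
except IndexError as e:
    print(e)
```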
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
🐛 Bug
SpeechRecognition Task predict() can not operate on raw tensor input.
To Reproduce
Steps to reproduce the behavior:
model.predict([input_tensor])
or, for the dictionary input:
model.predict({"input_values": input_tensor})
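For context on the dictionary form: if the predict pipeline batches dict samples with PyTorch's default collation (an assumption on my part, not verified against flash internals), a list of {"input_values": tensor} dicts collates into a single dict of stacked tensors, which is the shape of input the model's forward step would then see:

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical dict samples shaped like the Wav2Vec2-style input above.
samples = [{"input_values": torch.randn(166960)} for _ in range(2)]

# A DataLoader applies default collation: the list of dicts becomes
# one dict whose values are stacked along a new batch dimension.
batch = next(iter(DataLoader(samples, batch_size=2)))
print(batch["input_values"].shape)  # torch.Size([2, 166960])
```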
Code sample
Expected behavior
It was expected that the model.predict() function would accept a tensor or input dictionary, since that is the "raw data" referenced in how predict is documented to work: https://github.com/PyTorchLightning/lightning-flash/blob/4b24b445ce9a65b21ab3fc0361d95d879f01cedf/flash/core/model.py#L466
Environment
Additional context
It is possible to bypass this error by:
1) sanitizing the input by subclassing the SpeechRecognition task, and
2) declaring a default data pipeline to pass to the predict stage.
This works, but it seems to me that this is a bug for the following reasons:
1) Passing a data pipeline to the predict function shouldn't change the interface of the function itself.
2) The docs suggest that either a filepath or raw data should work with predict (this may be a misreading on my part).
3) The way predict() works (creating a pipeline on every call) is too slow for my purposes (benchmarking), which is why I couldn't use the default behavior of creating a CSV so I could load a datamodule in the first place.