pytorch / audio

Data manipulation and transformation for audio signal processing, powered by PyTorch
https://pytorch.org/audio
BSD 2-Clause "Simplified" License
2.54k stars, 652 forks

Fix transducer test #1251

Closed: mthrok closed this 3 years ago

mthrok commented 3 years ago

RNN-Transducer (RNNT) is a recently added prototype module, and users can opt in or out of it when building torchaudio.

The unit test for RNNT is supposed to be skipped when RNNT is not available, but instead it is skipped only when the C++ extension is not available. The test still runs when the C++ extension is available but was not compiled with RNNT.

https://github.com/pytorch/audio/blob/77de2b9616c47c114124a8e002447ad927799dbd/test/torchaudio_unittest/transducer_test.py#L274-L275

Steps for fix

  1. Replace the skipIfNoExtension decorator with one whose logic is specific to RNNT loss. The implementation should check whether torch.ops.torchaudio.rnnt_loss is accessible (you can use a try/except clause). The new decorator should live in the same file as the test (transducer_test.py). Let's call it skipIfNoRNNT.
  2. Replace PytorchTestCase with TorchaudioTestCase. This is not related to the test skip logic, but PytorchTestCase is not the right class to use here, so let's fix it.

Build and test

  1. Install nightly build of PyTorch
  2. Clone repo
  3. Run the test with pytest test/torchaudio_unittest/transducer_test.py -v and confirm that the test is not skipped and fails. (The test should not pass.)
  4. Fix the issue
  5. Run the test again and see that it is skipped properly

Before the fix, the test should fail with a message like RuntimeError: No such operator torchaudio::rnnt_loss, as follows:

self = <[RuntimeError('No such operator torchaudio::__file__') raised in repr()] _OpNamespace object at 0x7f31d9e003b0>, op_name = 'rnnt_loss'

    def __getattr__(self, op_name):
        # Get the op `my_namespace::my_op` if available. This will also check
        # for overloads and raise an exception if there are more than one.
        qualified_op_name = '{}::{}'.format(self.name, op_name)
>       op = torch._C._jit_get_operation(qualified_op_name)
E       RuntimeError: No such operator torchaudio::rnnt_loss

/home/moto/conda/envs/PY3.8-cuda101/lib/python3.8/site-packages/torch/_ops.py:61: RuntimeError

After the fix, the test should be properly skipped.

jieruan02 commented 3 years ago

Can I try this? Also, pytest test/torchaudio_unittest/transducer_test.py -v passes in my case. Is it because I installed with pip?