pytorch / audio

Data manipulation and transformation for audio signal processing, powered by PyTorch
https://pytorch.org/audio
BSD 2-Clause "Simplified" License

Bugfix: Enforcing batch as type Batch for librispeech_conformer_rnnt … #3708

Closed kikofmas closed 3 months ago

kikofmas commented 7 months ago

Hi

I created an issue (#3707) where I detailed a bug with the librispeech_conformer_rnnt ASR example. After some digging I found the reason for that error: the variable batch loses its type Batch (a named tuple) when a function is called, so attribute access such as batch.targets fails.

In the file lightning.py we have:

#
# ...
#

    def _step(self, batch, _, step_type):
        if batch is None:
            return None

        prepended_targets = batch.targets.new_empty([batch.targets.size(0), batch.targets.size(1) + 1])

#
# ...
#

    def training_step(self, batch: Batch, batch_idx):
#
# ...
#
        loss = self._step(batch, batch_idx, "train")
        batch_size = batch.features.size(0)
#
# ...
#

Simply recasting the tuple back to Batch solves this. I changed the code to:

#
# ...
#

    def _step(self, batch: Batch, _, step_type):
        if batch is None:
            return None

        batch = Batch(batch[0], batch[1], batch[2], batch[3])

        prepended_targets = batch.targets.new_empty([batch.targets.size(0), batch.targets.size(1) + 1])

#
# ...
#
    def training_step(self, batch: Batch, batch_idx):
#
# ...
#
        batch = Batch(batch[0], batch[1], batch[2], batch[3])
        loss = self._step(batch, batch_idx, "train")
        batch_size = batch.features.size(0)
#
# ...
#
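
For reference, the failure mode and the recast can be reproduced without PyTorch. The snippet below is only a minimal sketch: the field names `feature_lengths` and `target_lengths` are assumptions (only `features` and `targets` appear in the code above), and `ensure_batch` is a hypothetical helper that is not part of the example or of torchaudio. It rebuilds the named tuple only when the type has actually been lost, which would avoid repeating the constructor call in every step function.

```python
from typing import Any, NamedTuple, Optional, Sequence


class Batch(NamedTuple):
    # Only `features` and `targets` are confirmed by the PR text; the two
    # length fields are assumed names for the remaining two positions.
    features: Any
    feature_lengths: Any
    targets: Any
    target_lengths: Any


def ensure_batch(batch: Optional[Sequence[Any]]) -> Optional[Batch]:
    """Hypothetical helper: return `batch` as a Batch, rebuilding it if needed."""
    if batch is None or isinstance(batch, Batch):
        return batch
    # Same effect as Batch(batch[0], batch[1], batch[2], batch[3]).
    return Batch(*batch)


original = Batch([[0.1, 0.2]], [2], [[7, 8]], [2])

# Simulate the type being lost somewhere between the data loader and the step.
degraded = tuple(original)

# A plain tuple has no named fields, so `degraded.targets` would raise
# AttributeError; this is the symptom described above.
has_targets_before = hasattr(degraded, "targets")

# After the recast, the named fields are available again.
restored = ensure_batch(degraded)
has_targets_after = hasattr(restored, "targets")
```

With a helper like this, `_step` and `training_step` could each start with `batch = ensure_batch(batch)` instead of spelling out the four positional arguments.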

I don't think this addresses the root cause of the problem, but I hope it is a useful PR that, at least, makes the example run. Thank you for all your work.

(edit: put code inside a code block)
