sooftware / conformer

[Unofficial] PyTorch implementation of "Conformer: Convolution-augmented Transformer for Speech Recognition" (INTERSPEECH 2020)
Apache License 2.0
958 stars 175 forks source link

Error when running the usage example (RuntimeError: Input tensor at index 1 has invalid shape [1, 3085, 8, 10], but expected [1, 3085, 9, 10]) #31

Closed — sovse closed this issue 3 years ago

sovse commented 3 years ago

Running the following code results in an error (the first output line is the printed PyTorch version):

import torch
print(torch.__version__)  # shown as the first output line below (1.9.0+cu111)
import torch.nn as nn
from conformer import Conformer

batch_size, sequence_length, dim = 3, 12345, 80

cuda = torch.cuda.is_available()  # on a multi-GPU host, nn.DataParallel below splits the batch (dim 0) across GPUs
device = torch.device('cuda' if cuda else 'cpu')

inputs = torch.rand(batch_size, sequence_length, dim).to(device)
input_lengths = torch.IntTensor([12345, 12300, 12000])  # left on CPU here; DataParallel scatters it along with the other args
targets = torch.LongTensor([[1, 3, 3, 3, 3, 3, 4, 5, 6, 2],
                            [1, 3, 3, 3, 3, 3, 4, 5, 2, 0],
                            [1, 3, 3, 3, 3, 3, 4, 2, 0, 0]]).to(device)  # trailing 0s presumably padding -- TODO confirm pad id
target_lengths = torch.LongTensor([9, 8, 7])  # differs per row; with one row per replica, the lengths appear to match the 9 vs 8 mismatch in the RuntimeError below

model = nn.DataParallel(Conformer(num_classes=10, input_dim=dim,  # NOTE(review): the gather error only occurs with >1 GPU; per this thread it runs fine on a single GPU
                                  encoder_dim=32, num_encoder_layers=3, 
                                  decoder_dim=32, device=device)).to(device)

# Forward propagate
outputs = model(inputs, input_lengths, targets, target_lengths)  # raises in DataParallel.gather when replica outputs differ in shape

# Recognize input speech
outputs = model.module.recognize(inputs, input_lengths)  # .module calls the wrapped Conformer directly, bypassing DataParallel replication
1.9.0+cu111
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-12-eea3aeffaf58> in <module>
     21 
     22 # Forward propagate
---> 23 outputs = model(inputs, input_lengths, targets, target_lengths)
     24 
     25 # Recognize input speech

/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
    167             replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
    168             outputs = self.parallel_apply(replicas, inputs, kwargs)
--> 169             return self.gather(outputs, self.output_device)
    170 
    171     def replicate(self, module, device_ids):

/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py in gather(self, outputs, output_device)
    179 
    180     def gather(self, outputs, output_device):
--> 181         return gather(outputs, output_device, dim=self.dim)
    182 
    183 

/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py in gather(outputs, target_device, dim)
     76     # Setting the function to None clears the refcycle.
     77     try:
---> 78         res = gather_map(outputs)
     79     finally:
     80         gather_map = None

/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py in gather_map(outputs)
     61         out = outputs[0]
     62         if isinstance(out, torch.Tensor):
---> 63             return Gather.apply(target_device, dim, *outputs)
     64         if out is None:
     65             return None

/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/_functions.py in forward(ctx, target_device, dim, *inputs)
     73             ctx.unsqueezed_scalar = False
     74         ctx.input_sizes = tuple(i.size(ctx.dim) for i in inputs)
---> 75         return comm.gather(inputs, ctx.dim, ctx.target_device)
     76 
     77     @staticmethod

/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/comm.py in gather(tensors, dim, destination, out)
    233                 'device object or string instead, e.g., "cpu".')
    234         destination = _get_device_index(destination, allow_cpu=True, optional=True)
--> 235         return torch._C._gather(tensors, dim, destination)
    236     else:
    237         if destination is not None:

RuntimeError: Input tensor at index 1 has invalid shape [1, 3085, 8, 10], but expected [1, 3085, 9, 10]

I am using Python 3.8.8. Which Python version is this expected to work with?

sooftware commented 3 years ago

What is the result in a single-GPU environment?

sovse commented 3 years ago

Works correctly with one GPU. Thanks.