openspeech-team / openspeech

Open-Source Toolkit for End-to-End Speech Recognition leveraging PyTorch-Lightning and Hydra.
https://openspeech-team.github.io/openspeech/
MIT License

forward() got an unexpected keyword argument 'log_probs' #220

Open ChaofanTao opened 7 months ago

ChaofanTao commented 7 months ago

Environment info

Information

I want to train ContextNet on the LibriSpeech dataset. Here is my training script, located in openspeech/scripts (on the first run I set dataset.dataset_download=True to download the dataset):

# sh scripts/train.sh 
python3 ./openspeech_cli/hydra_train.py \
    dataset=librispeech \
    dataset.dataset_download=False \
    dataset.dataset_path=$DATASET_PATH \
    dataset.manifest_file_path=$MANIFEST_FILE_PATH \
    tokenizer=libri_subword \
    model=contextnet \
    audio=fbank \
    lr_scheduler=warmup_reduce_lr_on_plateau \
    trainer=gpu \
    criterion=cross_entropy

Running it fails with the following error:

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py", line 139, in _wrapping_function
    results = function(*args, **kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 645, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1098, in _run
    results = self._run_stage()
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1177, in _run_stage
    self._run_train()
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1190, in _run_train
    self._run_sanity_check()
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1262, in _run_sanity_check
    val_loop.run()
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 152, in advance
    dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 137, in advance
    output = self._evaluation_step(**kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 234, in _evaluation_step
    output = self.trainer._call_strategy_hook(hook_name, *kwargs.values())
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1480, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/strategies/ddp_spawn.py", line 288, in validation_step
    return self.model(*args, **kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/overrides/base.py", line 110, in forward
    return self._forward_module.validation_step(*inputs, **kwargs)
  File "/home/mnt/cftao/openspeech/openspeech/models/contextnet/model.py", line 133, in validation_step
    return self.collect_outputs(
  File "/home/mnt/cftao/openspeech/openspeech/models/openspeech_ctc_model.py", line 73, in collect_outputs
    loss = self.criterion(
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: forward() got an unexpected keyword argument 'log_probs'
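The traceback ends in `openspeech_ctc_model.py`, where `collect_outputs` calls `self.criterion(...)`, and the `TypeError` says the configured criterion's `forward()` has no `log_probs` parameter. One plausible reading, not confirmed in this thread, is that `model=contextnet` is a CTC-based model that passes CTC-style keyword arguments to its loss, while `criterion=cross_entropy` selects a loss whose `forward()` expects different arguments. The minimal pure-Python sketch below reproduces that mechanism under those assumptions. `CrossEntropyLikeCriterion`, `CTCLikeCriterion`, `collect_outputs`, `accepts_keywords`, and every keyword name other than `log_probs` are hypothetical stand-ins, not openspeech's actual code.

```python
import inspect


# Hypothetical stand-ins for two loss modules. They only mimic the *call
# signatures* involved; they are not openspeech's classes and compute no loss.
class CrossEntropyLikeCriterion:
    """Expects (logits, targets), like a cross-entropy loss."""

    def forward(self, logits, targets):
        return "cross-entropy loss"

    def __call__(self, *args, **kwargs):
        # torch.nn.Module.__call__ likewise routes its arguments to forward().
        return self.forward(*args, **kwargs)


class CTCLikeCriterion:
    """Expects CTC-style keyword arguments, including log_probs."""

    def forward(self, log_probs, input_lengths, targets, target_lengths):
        return "ctc loss"

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


def collect_outputs(criterion, log_probs, output_lengths, targets, target_lengths):
    """Assumed shape of the CTC model's loss call: keyword arguments only."""
    return criterion(
        log_probs=log_probs,
        input_lengths=output_lengths,
        targets=targets,
        target_lengths=target_lengths,
    )


def accepts_keywords(criterion, names):
    """Pre-flight check: can criterion.forward be called with these keywords?"""
    params = inspect.signature(criterion.forward).parameters
    # A **kwargs parameter accepts any keyword.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return True
    return all(
        name in params and params[name].kind is not inspect.Parameter.POSITIONAL_ONLY
        for name in names
    )


if __name__ == "__main__":
    batch = dict(log_probs=[], output_lengths=[], targets=[], target_lengths=[])
    print(collect_outputs(CTCLikeCriterion(), **batch))  # ctc loss
    try:
        collect_outputs(CrossEntropyLikeCriterion(), **batch)
    except TypeError as err:
        # Same class of error as in the traceback: no 'log_probs' parameter.
        print("TypeError:", err)
```

If this reading is right, the fix is to pair the CTC model with a criterion whose `forward()` accepts `log_probs`. openspeech appears to provide a CTC criterion (`criterion=ctc` is the likely override), but that name is an assumption to verify against the criterion options in the repository. A check like `accepts_keywords` can surface such a model/criterion mismatch before the sanity-check validation step runs.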

How can I solve this problem? Thanks.

upskyy commented 7 months ago

@ChaofanTao Thank you for reporting the issue. I will check and leave a comment.