Audio-WestlakeU / FullSubNet

PyTorch implementation of "FullSubNet: A Full-Band and Sub-Band Fusion Model for Real-Time Single-Channel Speech Enhancement."
https://fullsubnet.readthedocs.io/en/latest/
MIT License

Training is wrong #30

Open · Beninmiao opened this issue 3 years ago

Beninmiao commented 3 years ago

I used your code to train on the full dataset, but the process pauses after loading train.toml. Have you tested the correctness of your code?

haoxiangsnr commented 3 years ago

Hi Beninmiao,

Thanks for your attention. Please tell me more about the error you encountered, and give as much detail as possible about your platform.

Malik7115 commented 3 years ago

Same for me. It gets stuck here (screenshot attached: 2021-10-26 15-52-29).

Apparently it does not move past this call in recipes/dns_interspeech_2020/train.py (line 59):

    trainer = trainer_class(
        dist=dist,
        rank=rank,
        config=config,
        resume=resume,
        only_validation=only_validation,
        model=model,
        loss_function=loss_function,
        optimizer=optimizer,
        train_dataloader=train_dataloader,
        validation_dataloader=valid_dataloader
    )
haoxiangsnr commented 3 years ago

Hi, @Malik7115.

Sorry for the late reply.

The latest FullSubNet is based on PyTorch 1.10 and uses torchrun (Elastic Launch) as the launcher.

I've updated the out-of-date documentation. Please upgrade your PyTorch to v1.10 and retry the experiments.
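
For reference, here is a rough environment sanity check you can run before launching. This is only a sketch (not part of the repo), and the torchrun invocation shown in the comment follows the updated docs, so double-check the exact train.py flags against your checkout:

    # Sketch: verify that the installed PyTorch is new enough for torchrun (ships with >= 1.10).
    # Training is then launched with something along the lines of:
    #   torchrun --standalone --nnodes=1 --nproc_per_node=1 train.py -C fullsubnet/train.toml
    import torch

    major, minor = (int(x) for x in torch.__version__.split(".")[:2])
    assert (major, minor) >= (1, 10), f"PyTorch {torch.__version__} is too old for torchrun"
    print(f"PyTorch {torch.__version__} detected; torchrun should be available.")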

If you have further questions, please feel free to contact me.

gooran commented 3 years ago

I have encountered this error while training the model. What is the reason?

ZeroDivisionError: float division by zero
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 22651) of binary: 

This is the code: https://colab.research.google.com/drive/1xf3JWWFazfi6qh4eIXgjayB0qR16NziC?usp=sharing

And this is the full log:

1 process initialized.
The configurations are as follows: 
{
    'meta': {
        'save_dir': '~/Experiments/FullSubNet/',
        'description': 'This is a description of FullSubNet experiment.',
        'seed': 0,
        'use_amp': True,
        'cudnn_enable': False,
        'experiment_name': 'train',
        'config_path': 'fullsubnet/train.toml',
        'preloaded_model_path': None
    },
    'acoustics': {
        'n_fft': 512,
        'win_length': 512,
        'sr': 16000,
        'hop_length': 256
    },
    'loss_function': {'name': 'mse_loss', 'args': {}},
    'optimizer': {'lr': 0.001, 'beta1': 0.9, 'beta2': 0.999},
    'train_dataset': {
        'path': 'dataset_train.Dataset',
        'args': {
            'clean_dataset': '~/Datasets/DNS-Challenge-INTERSPEECH/datasets/clean_0.6.txt',
            'clean_dataset_limit': False,
            'clean_dataset_offset': 0,
            'noise_dataset': '~/Datasets/DNS-Challenge-INTERSPEECH/datasets/noise.txt',
            'noise_dataset_limit': False,
            'noise_dataset_offset': 0,
            'num_workers': 36,
            'pre_load_clean_dataset': False,
            'pre_load_noise': False,
            'pre_load_rir': False,
            'reverb_proportion': 0.75,
            'rir_dataset': '~/Datasets/DNS-Challenge-INTERSPEECH/datasets/rir.txt',
            'rir_dataset_limit': False,
            'rir_dataset_offset': 0,
            'silence_length': 0.2,
            'snr_range': [-5, 20],
            'sr': 16000,
            'sub_sample_length': 3.072,
            'target_dB_FS': -25,
            'target_dB_FS_floating_value': 10
        },
        'dataloader': {
            'batch_size': 9,
            'num_workers': 36,
            'drop_last': True,
            'pin_memory': False
        }
    },
    'validation_dataset': {
        'path': 'dataset_validation.Dataset',
        'args': {
            'dataset_dir_list': [
                '~/Datasets/DNS-Challenge-INTERSPEECH/datasets/test_set/synthetic/with_reverb/',
                '~/Datasets/DNS-Challenge-INTERSPEECH/datasets/test_set/synthetic/no_reverb/'
            ],
            'sr': 16000
        }
    },
    'model': {
        'path': 'fullsubnet.model.Model',
        'args': {
            'sb_num_neighbors': 15,
            'fb_num_neighbors': 0,
            'num_freqs': 257,
            'look_ahead': 2,
            'sequence_model': 'LSTM',
            'fb_output_activate_function': 'ReLU',
            'sb_output_activate_function': False,
            'fb_model_hidden_size': 512,
            'sb_model_hidden_size': 384,
            'weight_init': False,
            'norm_type': 'offline_laplace_norm',
            'num_groups_in_drop_band': 4
        }
    },
    'trainer': {
        'path': 'trainer.Trainer',
        'train': {
            'clip_grad_norm_value': 10,
            'epochs': 9999,
            'save_checkpoint_interval': 2
        },
        'validation': {
            'save_max_metric_score': True,
            'validation_interval': 2
        },
        'visualization': {
            'metrics': ['WB_PESQ', 'NB_PESQ', 'STOI', 'SI_SDR'],
            'n_samples': 10,
            'num_workers': 36
        }
    }
}
This project contains 1 models, the number of the parameters is: 
        Network 1: 5.637635 million.
The amount of parameters in the project is 5.637635 million.
=============== 1 epoch ===============
[0 seconds] Begin training...
[27 seconds] This epoch is finished.
=============== 2 epoch ===============
[0 seconds] Begin training...
         Saving 2 epoch model checkpoint...
[28 seconds] Training has finished, validation is in progress...
/usr/local/lib/python3.7/site-packages/torch/utils/data/dataloader.py:481: UserWarning: This DataLoader will create 36 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  cpuset_checked))
Training:   0%|          | 0/13 [00:00<?, ?it/s]/usr/local/lib/python3.7/site-packages/torch/utils/data/dataloader.py:481: UserWarning: This DataLoader will create 36 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  cpuset_checked))
Training: 100%|██████████| 13/13 [00:27<00:00,  2.15s/it]
Training: 100%|██████████| 13/13 [00:28<00:00,  2.16s/it]
Validation: 0it [00:00, ?it/s]
Traceback (most recent call last):
  File "train.py", line 99, in <module>
    entry(local_rank, configuration, args.resume, args.only_validation)
  File "train.py", line 72, in entry
    trainer.train()
  File "/content/FullSubNet/audio_zen/trainer/base_trainer.py", line 337, in train
    metric_score = self._validation_epoch(epoch)
  File "/usr/local/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/content/FullSubNet/recipes/dns_interspeech_2020/fullsubnet/trainer.py", line 111, in _validation_epoch
    self.writer.add_scalar(f"Loss/Validation_Total", loss_total / len(self.valid_dataloader), epoch)
ZeroDivisionError: float division by zero
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 22651) of binary: /usr/local/bin/python
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/local/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.7/site-packages/torch/distributed/run.py", line 723, in <module>
    main()
  File "/usr/local/lib/python3.7/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 345, in wrapper
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.7/site-packages/torch/distributed/run.py", line 719, in main
    run(args)
  File "/usr/local/lib/python3.7/site-packages/torch/distributed/run.py", line 713, in run
    )(*cmd_args)
  File "/usr/local/lib/python3.7/site-packages/torch/distributed/launcher/api.py", line 131, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.7/site-packages/torch/distributed/launcher/api.py", line 261, in launch_agent
    failures=result.failures,
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
train.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2021-11-03_18:13:48
  host      : daf4d85e2fc7
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 22651)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
haoxiangsnr commented 3 years ago

Hi @gooran,

Thanks for your attention.

I noticed that you interrupted the clone process in this cell. As a result, you may not have downloaded any validation audio files into your ~/Datasets/DNS-Challenge-INTERSPEECH/datasets/test_set/synthetic/ directory. The output of your training run contains the following lines, which show that no validation audio files were found:

...

Validation: 0it [00:00, ?it/s]

...

  File "/content/FullSubNet/recipes/dns_interspeech_2020/fullsubnet/trainer.py", line 111, in _validation_epoch
    self.writer.add_scalar(f"Loss/Validation_Total", loss_total / len(self.valid_dataloader), epoch)
ZeroDivisionError: float division by zero

...
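
As a quick sanity check, you can verify that the synthetic test set was actually downloaded before starting a run. Here is a small sketch using the dataset paths from your config (purely illustrative; adjust the paths and file extension if your layout differs):

    # Sketch: count validation wav files in the two synthetic test-set directories from the
    # config. If either is empty, the validation dataloader has length 0 and the loss
    # averaging (loss_total / len(self.valid_dataloader)) divides by zero.
    from pathlib import Path

    dirs = [
        "~/Datasets/DNS-Challenge-INTERSPEECH/datasets/test_set/synthetic/with_reverb/",
        "~/Datasets/DNS-Challenge-INTERSPEECH/datasets/test_set/synthetic/no_reverb/",
    ]
    for d in (Path(p).expanduser() for p in dirs):
        n = len(list(d.rglob("*.wav"))) if d.exists() else 0
        print(f"{d}: {n} wav files")
        if n == 0:
            print("  -> empty or missing; re-run the dataset download for this split.")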

If you have further questions, please feel free to contact me.

gooran commented 3 years ago

Hi, thank you for your kind cooperation. I'm working on it...