Audio-WestlakeU / NBSS

The official repo of NBC & SpatialNet for multichannel speech separation, denoising, and dereverberation
MIT License

Unable to run training #22

Closed by kfmn 8 months ago

kfmn commented 8 months ago

Hi,

I am trying to train SpatialNet on SMS_WSJ data, using the example from your README:

python SharedTrainer.py fit \
 --config=configs/SpatialNet.yaml \ # network config
 --config=configs/datasets/sms_wsj_plus.yaml \ # dataset config
 --model.channels=[0,1,2,3,4,5] \ # the channels used
 --model.arch.dim_input=12 \ # input dim per T-F point, i.e. 2 * the number of channels
 --model.arch.dim_output=4 \ # output dim per T-F point, i.e. 2 * the number of sources
 --model.arch.num_freqs=129 \ # the number of frequencies, related to model.stft.n_fft
 --trainer.precision=bf16-mixed \ # mixed precision training, can also be 16-mixed or 32, where 32 can produce the best performance
 --model.compile=true \ # compile the network, requires torch>=2.0. the compiled model is trained much faster
 --data.batch_size=[2,4] \ # batch size for train and val
 --trainer.devices=0, \
 --trainer.max_epochs=100

but I have run into some trouble.

  1. When attempting to run it, I got multiple errors about unknown parameters. I ran python SharedTrainer.py fit --help and found that parameter values are expected to be passed with a space, not =. Moreover, configs for model and data are expected to be passed via --trainer and --data, not --config.
  2. After fixing the above issues and running again, I got the following error:
    
    error: Parser key "data":
    Not a valid subclass of LightningDataModule. 
    .......
    Subclass types expect one of:
    - a class path (str)
    - a dict with class_path entry
    - a dict without class_path but with init_args entry (class path given previously)

I looked into configs/datasets/sms_wsj_plus.yaml and found that its top level is indeed a dict without a 'class_path' key; only the value under the key 'data' has one:

    data:
      class_path: data_loaders.sms_wsj_plus.SmsWsjPlusDataModule
      init_args:
        sms_wsj_dir: data/sms_wsj/data
        rir_dir: datasets/SMS_WSJ_Plus_rirs/
        target: direct_path
        datasets: ["train_si284", "test_dev93", "test_eval92", "test_eval92"]
        audio_time_len: [4.0, 4.0, null, null]
        ovlp: mid
        speech_overlap_ratio: [0.1, 1.0]
        sir: [-5, 5]
        snr: [0, 20]
        num_spk: 2
        noise_type: ["babble", "white"]
        batch_size: [2, 1]

  3. So I removed the top-level key 'data' so that the dict itself has a 'class_path' key. But now I get a different error:

error: Parser key "data": 'type' object is not subscriptable


I also tried replacing the path to the yaml file with the actual class path data_loaders.sms_wsj_plus.SmsWsjPlusDataModule for the --data parameter, but I get the same error.

Please explain what I am doing wrong.

Best regards,
Maxim

kfmn commented 8 months ago

It seems the problem was the Python version (3.8). It works with 3.10.
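
One plausible reading of the "'type' object is not subscriptable" message, not confirmed anywhere in this thread: built-in types such as list and dict can only be subscripted in type hints (list[int], dict[str, float]) from Python 3.9 onward (PEP 585). On 3.8 the same expression raises exactly this TypeError. Whether such an annotation is what actually failed in NBSS is an assumption. The sketch below only shows how to check an interpreter for this behaviour; the function names are hypothetical and are not part of the repo.

```python
import sys


def builtin_generics_supported() -> bool:
    """Return True if built-in types can be subscripted (PEP 585, Python 3.9+)."""
    try:
        # On Python 3.8 this raises:
        #   TypeError: 'type' object is not subscriptable
        list[int]
        return True
    except TypeError:
        return False


def describe_interpreter() -> str:
    """Summarise whether this interpreter accepts annotations like list[int]."""
    version = ".".join(str(part) for part in sys.version_info[:3])
    if builtin_generics_supported():
        return f"Python {version}: list[int] style annotations are accepted"
    return f"Python {version}: list[int] fails, use a newer Python (3.10 worked in this thread)"


if __name__ == "__main__":
    print(describe_interpreter())
```

Running this with the same interpreter that launches SharedTrainer.py shows quickly whether the environment is affected.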

SinoHero commented 4 months ago

Hello! Did you solve the problem? I'm facing the same thing.