I am trying to train SpatialNet on the SMS_WSJ data. I used the example from your README:
python SharedTrainer.py fit \
--config=configs/SpatialNet.yaml \ # network config
--config=configs/datasets/sms_wsj_plus.yaml \ # dataset config
--model.channels=[0,1,2,3,4,5] \ # the channels used
--model.arch.dim_input=12 \ # input dim per T-F point, i.e. 2 * the number of channels
--model.arch.dim_output=4 \ # output dim per T-F point, i.e. 2 * the number of sources
--model.arch.num_freqs=129 \ # the number of frequencies, related to model.stft.n_fft
--trainer.precision=bf16-mixed \ # mixed precision training, can also be 16-mixed or 32, where 32 can produce the best performance
--model.compile=true \ # compile the network, requires torch>=2.0. the compiled model is trained much faster
--data.batch_size=[2,4] \ # batch size for train and val
--trainer.devices=0, \
--trainer.max_epochs=100
but I have run into some trouble.
When attempting to run it, I got multiple errors about unknown parameters. I invoked python SharedTrainer.py fit --help and found that parameter values are expected to be passed with a space, not =. Moreover, configs for model and data are expected to be passed via --trainer and --data, not --config.
After fixing the above issues and running again, I got the following error:
error: Parser key "data":
Not a valid subclass of LightningDataModule.
.......
Subclass types expect one of:
- a class path (str)
- a dict with class_path entry
- a dict without class_path but with init_args entry (class path given previously)
I looked into configs/datasets/sms_wsj_plus.yaml and found that its top-level dict indeed has no 'class_path' key; only the value under the 'data' key has one.
So I removed the top-level 'data' key so that the dict itself has a 'class_path' key. But now I am getting another error:
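To make the structure concrete, here is a minimal Python sketch of the two nesting levels. It is not the parser's actual code: `has_class_path` and `instantiate` are hypothetical helpers written for illustration, the config is abridged, and a standard-library class stands in for the data module because the real one can only be imported inside the repository.

```python
import importlib

# Abridged copy of the structure in configs/datasets/sms_wsj_plus.yaml.
config = {
    "data": {
        "class_path": "data_loaders.sms_wsj_plus.SmsWsjPlusDataModule",
        "init_args": {"num_spk": 2, "batch_size": [2, 1]},
    }
}


def has_class_path(node):
    """Return True if `node` is a dict carrying a 'class_path' entry."""
    return isinstance(node, dict) and "class_path" in node


def instantiate(node):
    """Hypothetical helper: import `class_path` and call it with `init_args`."""
    module_name, _, class_name = node["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**node.get("init_args", {}))


print(has_class_path(config))          # top level: only the 'data' key
print(has_class_path(config["data"]))  # nested level: has 'class_path'

# Stand-in demonstration with a standard-library class.
frac = instantiate({
    "class_path": "fractions.Fraction",
    "init_args": {"numerator": 3, "denominator": 4},
})
print(frac)
```

The top-level dict fails the check while `config["data"]` passes it, which matches the "Not a valid subclass" message when the whole file is handed to --data.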
error: Parser key "data":
'type' object is not subscriptable
I also tried replacing the path to the YAML file with the actual class path data_loaders.sms_wsj_plus.SmsWsjPlusDataModule for the --data parameter, but I get the same error.
Could you please explain what I am doing wrong?
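For context, one well-known source of the message "'type' object is not subscriptable" is subscripting a built-in type in an annotation, such as list[int], which Python only supports from version 3.9. Whether that is the cause in this case is an assumption and has not been verified against the repository code. The sketch below only checks whether the running interpreter supports that syntax; `supports_builtin_generics` is a hypothetical helper name.

```python
import sys


def supports_builtin_generics():
    """Return True if built-in types can be subscripted, e.g. list[int]."""
    try:
        list[int]  # raises TypeError on Python < 3.9
    except TypeError:
        return False
    return True


# Print the interpreter version next to the check result.
print(sys.version_info[:3], supports_builtin_generics())
```

If this prints False, any annotation written as list[...] or tuple[...] would fail at import time with the same TypeError.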
Best regards,
Maxim
Contents of configs/datasets/sms_wsj_plus.yaml:
data:
  class_path: data_loaders.sms_wsj_plus.SmsWsjPlusDataModule
  init_args:
    sms_wsj_dir: data/sms_wsj/data
    rir_dir: datasets/SMS_WSJ_Plus_rirs/
    target: direct_path
    datasets: ["train_si284", "test_dev93", "test_eval92", "test_eval92"]
    audio_time_len: [4.0, 4.0, null, null]
    ovlp: mid
    speech_overlap_ratio: [0.1, 1.0]
    sir: [-5, 5]
    snr: [0, 20]
    num_spk: 2
    noise_type: ["babble", "white"]
    batch_size: [2, 1]