Audio-WestlakeU / NBSS

The official repo of NBC & SpatialNet for multichannel speech separation, denoising, and dereverberation
MIT License

On the upper limit of performance gains from scaling up SpatialNet's parameter count #21

Open · rookie0607 opened this issue 4 months ago

rookie0607 commented 4 months ago

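Thank you for open-sourcing this excellent work. I have a question: according to the paper, going from SpatialNet-small to SpatialNet-large gives a fairly large performance gain. Have you tried an even larger SpatialNet, or is SpatialNet-large already the ceiling?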

quancs commented 4 months ago

I haven't tried larger parameter counts; there should still be room for improvement.

rookie0607 commented 4 months ago

@quancs I can't reproduce the WHAMR! results reported in your paper. Could you help me figure out what went wrong? The training and evaluation data were simulated with the scripts at https://my-bucket-a8b4b49c25c811ee9a7e8bba05fa24c7.s3.amazonaws.com/whamr_scripts.tar.gz, using the min, 8 kHz version.

Training command:

python SharedTrainer.py fit \
  --config=configs/SpatialNet.yaml \
  --config=configs/datasets/whamr.yaml \
  --model.channels=[0,1] \
  --model.arch.dim_input=4 \
  --model.arch.dim_output=4 \
  --model.arch.num_freqs=129 \
  --trainer.precision=bf16-mixed \
  --model.compile=True \
  --data.batch_size=[8,8] \
  --trainer.devices=0,1,2,3, \
  --trainer.max_epochs=250

SpatialNet.yaml:

seed_everything: 2
trainer:
  gradient_clip_val: 5
  gradient_clip_algorithm: norm
  devices: null
  accelerator: gpu
  strategy: auto
  sync_batchnorm: false
  precision: 32
model:
  arch:
    class_path: models.arch.SpatialNet.SpatialNet
    init_args:
      dim_input: 12
      # dim_output: 4
      num_layers: 8 # 12 for large
      encoder_kernel_size: 5
      dim_hidden: 96 # 192 for large
      dim_ffn: 192 # 384 for large
      num_heads: 4
      dropout: [0, 0, 0]
      kernel_size: [5, 3]
      conv_groups: [8, 8]
      norms: ["LN", "LN", "GN", "LN", "LN", "LN"]
      dim_squeeze: 8 # 16 for large
      num_freqs: 129
      full_share: 0
  channels: [0, 1, 2, 3, 4, 5]
  ref_channel: 0
  stft:
    class_path: models.io.stft.STFT
    init_args:
      n_fft: 256
      n_hop: 128
  loss:
    class_path: models.io.loss.Loss
    init_args:
      loss_func: models.io.loss.neg_si_sdr
      pit: True
  norm:
    class_path: models.io.norm.Norm
    init_args:
      mode: frequency
  optimizer: [Adam, { lr: 0.002 }]
  lr_scheduler: [ExponentialLR, { gamma: 0.99 }]
  exp_name: exp
  metrics: [SDR, SI_SDR, NB_PESQ, WB_PESQ, eSTOI]
  val_metric: loss
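
Regarding the original scaling question: the "for large" comments in the config list the values that distinguish SpatialNet-large from SpatialNet-small. A minimal sketch that turns them into command-line overrides in the same `--model.arch.*` form the training command already uses, so a larger variant can be tried without editing the YAML. This assumes jsonargparse accepts these keys as overrides the way it accepts `num_freqs`, `dim_input` and `dim_output`; `arch_overrides` is a hypothetical helper, not part of the repo.

```python
# Values copied from the "# ... for large" comments in SpatialNet.yaml.
LARGE = {"num_layers": 12, "dim_hidden": 192, "dim_ffn": 384, "dim_squeeze": 16}


def arch_overrides(values: dict, prefix: str = "--model.arch.") -> list:
    """Turn a dict of architecture hyperparameters into CLI override flags."""
    return [f"{prefix}{name}={value}" for name, value in values.items()]


if __name__ == "__main__":
    # Append the printed flags to the SharedTrainer.py fit command.
    print(" ".join(arch_overrides(LARGE)))
```

Going beyond "large" would mean raising these numbers further; no results for such configurations are reported in this thread.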

whamr.yaml:

data:
  class_path: data_loaders.whamr.WHAMRDataModule
  init_args:
    whamr_dir: /home/data/en_train_data/wham/whamr_2ch
    version: min
    target: anechoic
    sample_rate: 8000
    audio_time_len: [4.0, 4.0, null]
    batch_size: [2, 1]
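
The `--model.arch.*` overrides in the training command follow from the data and STFT settings. A small sketch of that arithmetic, under two assumptions not stated in the thread: `dim_input` counts the real and imaginary STFT parts of each selected microphone (consistent with `dim_input: 12` for the six default channels and 4 for `channels=[0,1]`), and `dim_output` does the same for each of the two WHAMR! speakers. `spatialnet_io_dims` is a hypothetical helper, not part of the repo.

```python
def spatialnet_io_dims(num_channels: int, num_speakers: int,
                       n_fft: int, n_hop: int, sample_rate: int) -> dict:
    """Derive SpatialNet input/output sizes and STFT timing (assumed mapping)."""
    return {
        "dim_input": 2 * num_channels,   # real + imaginary part per microphone (assumption)
        "dim_output": 2 * num_speakers,  # real + imaginary part per speaker (assumption)
        "num_freqs": n_fft // 2 + 1,     # one-sided STFT bins
        "frame_ms": 1000.0 * n_fft / sample_rate,
        "hop_ms": 1000.0 * n_hop / sample_rate,
    }


if __name__ == "__main__":
    # WHAMR! 2-channel, 8 kHz, with n_fft=256 and n_hop=128 from SpatialNet.yaml.
    print(spatialnet_io_dims(num_channels=2, num_speakers=2,
                             n_fft=256, n_hop=128, sample_rate=8000))
```

With n_fft = 256 at 8 kHz this gives 129 frequency bins, 32 ms frames and a 16 ms hop, matching `--model.arch.num_freqs=129`.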

Test results after 225 epochs of training:

Method             SI-SDR (dB)  SDR (dB)  NB-PESQ  eSTOI
unproc.            -6.1         -3.5      1.41     0.317
SpatialNet-small   11.8         13.1      2.93     0.826
SpatialNet-large   14.1         15.0      3.16     0.870
Ours (small)       7.93         10.62     2.61     0.648

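The config trains with `loss_func: models.io.loss.neg_si_sdr` and `pit: True`, and the table reports SI-SDR. For anyone comparing numbers, here is a generic NumPy sketch of SI-SDR and of a permutation-invariant negative-SI-SDR loss for one utterance. It follows the textbook definition and is not the repo's `models.io.loss` code, which may differ in details such as epsilon handling and batching.

```python
import itertools

import numpy as np


def si_sdr(est: np.ndarray, ref: np.ndarray, eps: float = 1e-8) -> float:
    """Scale-invariant SDR in dB for 1-D signals (zero-mean variant)."""
    est = est - est.mean()
    ref = ref - ref.mean()
    alpha = np.dot(est, ref) / (np.dot(ref, ref) + eps)  # optimal scaling of the reference
    target = alpha * ref
    noise = est - target
    return float(10 * np.log10((np.dot(target, target) + eps) / (np.dot(noise, noise) + eps)))


def pit_neg_si_sdr(ests: np.ndarray, refs: np.ndarray):
    """Return (loss, perm) minimising negative mean SI-SDR over speaker permutations.

    ests, refs: arrays of shape (num_speakers, num_samples);
    perm[i] is the index of the estimate assigned to reference i.
    """
    best_loss, best_perm = None, None
    for perm in itertools.permutations(range(refs.shape[0])):
        loss = -np.mean([si_sdr(ests[p], refs[i]) for i, p in enumerate(perm)])
        if best_loss is None or loss < best_loss:
            best_loss, best_perm = loss, perm
    return float(best_loss), best_perm
```

Relative to the unprocessed mixture (-6.1 dB SI-SDR), the table corresponds to SI-SDR improvements of 17.9 dB (SpatialNet-small), 20.2 dB (SpatialNet-large) and 14.03 dB for the reproduction.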
quancs commented 4 months ago

I looked up the configuration from my earlier experiments and compared it with yours. There are two main differences:

rookie0607 commented 3 months ago

How well does SpatialNet separate overlapping speech from three or more speakers? @quancs

quancs commented 3 months ago

I haven't tried separation with three simultaneous speakers.

rookie0607 commented 3 months ago

@quancs Could SpatialNet be made into a streaming model? How would one go about it? If not, could you recommend a few CSS models as strong as SpatialNet? Thank you!

quancs commented 3 months ago

@rookie0607 The streaming version will be open-sourced soon. This is our paper on the streaming implementation: Multichannel Long-Term Streaming Neural Speech Enhancement for Static and Moving Speakers

rookie0607 commented 3 months ago

Awesome! Really looking forward to it!

quancs commented 3 months ago

@rookie0607 The streaming version is now open source: models/arch/OnlineSpatialNet.py, with the corresponding config configs/onlineSpatialNet.yaml

rookie0607 commented 3 months ago

Nice! I pulled the latest code, but ran into some problems while training OnlineSpatialNet. Could you take a look?

python SharedTrainer.py fit --config=configs/onlineSpatialNet.yaml \
  --config=configs/datasets/whamr.yaml \
  --model.channels=[0,1] \
  --model.arch.dim_input=4 \
  --model.arch.dim_output=4 \
  --model.arch.num_freqs=129 \
  --trainer.precision=16-mixed \
  --model.compile=True \
  --data.batch_size=[2,2] \
  --trainer.devices=0 \
  --trainer.max_epochs=100

Error log:

usage: SharedTrainer.py [options] fit [-h] [-c CONFIG] [--print_config[=flags]] [--seed_everything SEED_EVERYTHING]....
SharedTrainer.py [options] fit: error: Parser key "model.arch": unsupported operand type(s) for |: '_UnionGenericAlias' and 'type'

Python environment:

Package                  Version
------------------------ ------------
absl-py                  1.4.0
addict                   2.4.0
aiohttp                  3.8.5
aiosignal                1.3.1
aliyun-python-sdk-core   2.14.0
aliyun-python-sdk-kms    2.16.2
antlr4-python3-runtime   4.9.3
archspec                 0.2.1
async-timeout            4.0.3
attrs                    23.1.0
audioread                3.0.0
boltons                  23.0.0
Brotli                   1.0.9
cachetools               5.3.1
causal-conv1d            1.2.0.post2
certifi                  2024.2.2
cffi                     1.16.0
cfgv                     3.4.0
charset-normalizer       2.0.4
click                    8.1.7
cmake                    3.27.2
conda                    24.1.1
conda-libmamba-solver    24.1.0
conda-package-handling   2.2.0
conda_package_streaming  0.9.0
crcmod                   1.7
cryptography             42.0.5
datasets                 2.17.1
decorator                5.1.1
dill                     0.3.7
distlib                  0.3.7
distro                   1.8.0
docstring_parser         0.16
editdistance             0.8.1
einops                   0.7.0
fairseq2                 0.1.0
fairseq2n                0.1.0
filelock                 3.12.2
frozenlist               1.4.0
fsspec                   2023.6.0
funasr                   1.0.10
future                   1.0.0
gast                     0.5.4
google-auth              2.22.0
google-auth-oauthlib     1.0.0
grpcio                   1.57.0
huggingface-hub          0.20.3
hydra-core               1.3.2
identify                 2.5.27
idna                     3.4
importlib-metadata       6.8.0
importlib_resources      6.3.1
jaconv                   0.3.4
jamo                     0.4.1
jieba                    0.42.1
Jinja2                   3.1.2
jiwer                    3.0.2
jmespath                 0.10.0
joblib                   1.3.2
jsonargparse             4.17.0
jsonpatch                1.32
jsonpointer              2.1
kaldiio                  2.18.0
lazy_loader              0.3
libmambapy               1.5.6
librosa                  0.10.1
lightning-utilities      0.11.0
lit                      16.0.6
llvmlite                 0.40.1
lxml                     4.9.3
mamba-ssm                1.2.0.post1
Markdown                 3.4.4
markdown-it-py           3.0.0
MarkupSafe               2.1.3
mdurl                    0.1.2
menuinst                 2.0.2
mir-eval                 0.7
modelscope               1.12.0
mpmath                   1.3.0
msgpack                  1.0.5
multidict                6.0.4
multiprocess             0.70.15
mypy-extensions          1.0.0
networkx                 3.1
ninja                    1.11.1.1
nodeenv                  1.8.0
numba                    0.57.1
numpy                    1.24.4
nvidia-cublas-cu11       11.10.3.66
nvidia-cuda-cupti-cu11   11.7.101
nvidia-cuda-nvrtc-cu11   11.7.99
nvidia-cuda-runtime-cu11 11.7.99
nvidia-cudnn-cu11        8.5.0.96
nvidia-cufft-cu11        10.9.0.58
nvidia-curand-cu11       10.2.10.91
nvidia-cusolver-cu11     11.4.0.1
nvidia-cusparse-cu11     11.7.4.91
nvidia-nccl-cu11         2.14.3
nvidia-nvtx-cu11         11.7.91
oauthlib                 3.2.2
omegaconf                2.3.0
oss2                     2.18.4
overrides                7.4.0
packaging                23.1
pandas                   2.0.3
pesq                     0.0.4
pillow                   10.2.0
pip                      23.3.1
platformdirs             3.10.0
pluggy                   1.0.0
pooch                    1.7.0
portalocker              2.7.0
pre-commit               3.3.3
protobuf                 4.24.1
psutil                   5.9.5
pyarrow                  13.0.0
pyarrow-hotfix           0.6
pyasn1                   0.5.0
pyasn1-modules           0.3.0
pycosat                  0.6.6
pycparser                2.21
pycryptodome             3.20.0
pydub                    0.25.1
Pygments                 2.17.2
pynndescent              0.5.11
pyre-extensions          0.0.30
PySocks                  1.7.1
pystoi                   0.4.1
python-dateutil          2.8.2
pytorch-lightning        2.0.0
pytorch-wpe              0.0.1
pytz                     2023.3
PyYAML                   6.0.1
rapidfuzz                2.13.7
regex                    2023.8.8
requests                 2.31.0
requests-oauthlib        1.3.1
rich                     13.7.1
rsa                      4.9
ruamel.yaml              0.17.21
ruamel.yaml.clib         0.2.8
sacrebleu                2.3.1
safetensors              0.4.2
scikit-learn             1.3.0
scipy                    1.11.2
seamless-communication   1.0.0
sentencepiece            0.2.0
setuptools               68.2.2
simplejson               3.19.2
six                      1.16.0
sortedcontainers         2.4.0
soundfile                0.12.1
soxr                     0.3.6
sympy                    1.12
tabulate                 0.9.0
tbb                      2021.8.0
tensorboard              2.14.0
tensorboard-data-server  0.7.1
threadpoolctl            3.2.0
tokenizers               0.15.2
tomli                    2.0.1
torch                    2.1.1+cu118
torch-complex            0.4.3
torchaudio               2.1.1+cu118
torcheval                0.0.6
torchmetrics             1.3.2
torchtnt                 0.2.0
torchvision              0.16.1+cu118
tqdm                     4.65.0
transformers             4.38.2
triton                   2.1.0
typeshed_client          2.5.1
typing_extensions        4.7.1
typing-inspect           0.9.0
tzdata                   2023.3
umap-learn               0.5.5
urllib3                  1.26.18
validators               0.23.2
virtualenv               20.24.3
Werkzeug                 2.3.7
wheel                    0.41.2
xxhash                   3.3.0
yapf                     0.40.2
yarl                     1.9.2
zipp                     3.16.2
zstandard                0.19.0

@quancs

quancs commented 3 months ago

Sorry, the code was incomplete. I just pushed a fix. I have tested the command below on my side and training runs (set --model.compile=false when using Mamba):

python SharedTrainer.py fit --config=configs/onlineSpatialNet.yaml \
    --config=configs/datasets/whamr.yaml \
    --model.channels=[0,1] \
    --model.arch.dim_input=4 \
    --model.arch.dim_output=4 \
    --model.arch.num_freqs=129 \
    --trainer.precision=16-mixed \
    --model.compile=true \
    --data.batch_size=[2,2] \
    --trainer.devices=0, \
    --trainer.max_epochs=100 \
    --model.stft.n_fft=256 \
    --model.stft.n_hop=128
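
In bash, a comment such as "# set to false for mamba" cannot sit after a line-continuation backslash: the backslash is then no longer the last character on the line, so the command ends there and the remaining flags are not passed. One way to keep a note next to each flag is to assemble the argument list in Python. A sketch, using exactly the flags from the command above; `ARGV` is just an illustrative name.

```python
import shlex

# Same flags as the shell command; comments can sit next to any entry.
ARGV = [
    "python", "SharedTrainer.py", "fit",
    "--config=configs/onlineSpatialNet.yaml",
    "--config=configs/datasets/whamr.yaml",
    "--model.channels=[0,1]",
    "--model.arch.dim_input=4",
    "--model.arch.dim_output=4",
    "--model.arch.num_freqs=129",
    "--trainer.precision=16-mixed",
    "--model.compile=true",  # set to false for mamba
    "--data.batch_size=[2,2]",
    "--trainer.devices=0,",
    "--trainer.max_epochs=100",
    "--model.stft.n_fft=256",
    "--model.stft.n_hop=128",
]

if __name__ == "__main__":
    # Print a safely quoted one-line command; subprocess.run(ARGV, check=True) would launch it.
    print(shlex.join(ARGV))
```

`shlex.join` quotes arguments such as `--model.channels=[0,1]`, which also keeps bash from treating the brackets as a glob pattern.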
rookie0607 commented 3 months ago

I pulled the latest code and ran the command you gave, but still got the same error.

quancs commented 3 months ago

@rookie0607 I just double-checked: both the command and the code work fine on my machine. You could check whether any of your package versions need upgrading.

quancs commented 3 months ago

Is that one line the only error message you get? It looks like a problem with the jsonargparse package.

quancs commented 3 months ago

Here are the packages and versions I have installed, for your reference:

a.txt

rookie0607 commented 3 months ago

@quancs It shouldn't be an environment issue. I aligned every package version with the a.txt you provided and still get this error:

usage: SharedTrainer.py [options] fit [-h] [-c CONFIG] [--print_config[=flags]] [--seed_everything SEED_EVERYTHING]
                                      [--trainer CONFIG] [--trainer.accelerator.help CLASS_PATH_OR_NAME]
                                      [--trainer.accelerator ACCELERATOR] [--trainer.strategy.help CLASS_PATH_OR_NAME]
                                      [--trainer.strategy STRATEGY] [--trainer.devices DEVICES] [--trainer.num_nodes NUM_NODES]
                                      [--trainer.precision PRECISION] [--trainer.logger.help CLASS_PATH_OR_NAME]
                                      [--trainer.logger LOGGER] [--trainer.callbacks.help CLASS_PATH_OR_NAME]
                                      [--trainer.callbacks CALLBACKS] [--trainer.fast_dev_run FAST_DEV_RUN]
                                      [--trainer.max_epochs MAX_EPOCHS] [--trainer.min_epochs MIN_EPOCHS]
                                      [--trainer.max_steps MAX_STEPS] [--trainer.min_steps MIN_STEPS]
                                      [--trainer.max_time MAX_TIME] [--trainer.limit_train_batches LIMIT_TRAIN_BATCHES]
                                      [--trainer.limit_val_batches LIMIT_VAL_BATCHES]
                                      [--trainer.limit_test_batches LIMIT_TEST_BATCHES]
                                      [--trainer.limit_predict_batches LIMIT_PREDICT_BATCHES]
                                      [--trainer.overfit_batches OVERFIT_BATCHES]
                                      [--trainer.val_check_interval VAL_CHECK_INTERVAL]
                                      [--trainer.check_val_every_n_epoch CHECK_VAL_EVERY_N_EPOCH]
                                      [--trainer.num_sanity_val_steps NUM_SANITY_VAL_STEPS]
                                      [--trainer.log_every_n_steps LOG_EVERY_N_STEPS]
                                      [--trainer.enable_checkpointing {true,false,null}]
                                      [--trainer.enable_progress_bar {true,false,null}]
                                      [--trainer.enable_model_summary {true,false,null}]
                                      [--trainer.accumulate_grad_batches ACCUMULATE_GRAD_BATCHES]
                                      [--trainer.gradient_clip_val GRADIENT_CLIP_VAL]
                                      [--trainer.gradient_clip_algorithm GRADIENT_CLIP_ALGORITHM]
                                      [--trainer.deterministic DETERMINISTIC] [--trainer.benchmark {true,false,null}]
                                      [--trainer.inference_mode {true,false}] [--trainer.use_distributed_sampler {true,false}]
                                      [--trainer.profiler.help CLASS_PATH_OR_NAME] [--trainer.profiler PROFILER]
                                      [--trainer.detect_anomaly {true,false}] [--trainer.barebones {true,false}]
                                      [--trainer.plugins.help CLASS_PATH_OR_NAME] [--trainer.plugins PLUGINS]
                                      [--trainer.sync_batchnorm {true,false}]
                                      [--trainer.reload_dataloaders_every_n_epochs RELOAD_DATALOADERS_EVERY_N_EPOCHS]
                                      [--trainer.default_root_dir DEFAULT_ROOT_DIR] [--model CONFIG]
                                      [--model.arch.help CLASS_PATH_OR_NAME] --model.arch ARCH --model.channels CHANNELS
                                      --model.ref_channel REF_CHANNEL [--model.stft.help CLASS_PATH_OR_NAME]
                                      [--model.stft STFT] [--model.norm.help CLASS_PATH_OR_NAME] [--model.norm NORM]
                                      [--model.loss.help CLASS_PATH_OR_NAME] [--model.loss LOSS] [--model.optimizer [ITEM,...]]
                                      [--model.lr_scheduler LR_SCHEDULER] [--model.metrics METRICS] [--model.mchunk MCHUNK]
                                      [--model.val_metric VAL_METRIC] [--model.write_examples WRITE_EXAMPLES]
                                      [--model.ensemble ENSEMBLE] [--model.compile {true,false}] [--model.exp_name EXP_NAME]
                                      [--model.reset RESET] [--data.help CLASS_PATH_OR_NAME]
                                      --data CONFIG | CLASS_PATH_OR_NAME | .INIT_ARG_NAME VALUE [--early_stopping CONFIG]
                                      [--early_stopping.enable {true,false}] [--early_stopping.monitor MONITOR]
                                      [--early_stopping.min_delta MIN_DELTA] [--early_stopping.patience PATIENCE]
                                      [--early_stopping.verbose {true,false}] [--early_stopping.mode MODE]
                                      [--early_stopping.strict {true,false}] [--early_stopping.check_finite {true,false}]
                                      [--early_stopping.stopping_threshold STOPPING_THRESHOLD]
                                      [--early_stopping.divergence_threshold DIVERGENCE_THRESHOLD]
                                      [--early_stopping.check_on_train_epoch_end {true,false,null}]
                                      [--early_stopping.log_rank_zero_only {true,false}] [--model_checkpoint CONFIG]
                                      [--model_checkpoint.dirpath DIRPATH] [--model_checkpoint.filename FILENAME]
                                      [--model_checkpoint.monitor MONITOR] [--model_checkpoint.verbose {true,false}]
                                      [--model_checkpoint.save_last {true,false,null}] [--model_checkpoint.save_top_k SAVE_TOP_K]
                                      [--model_checkpoint.save_weights_only {true,false}] [--model_checkpoint.mode MODE]
                                      [--model_checkpoint.auto_insert_metric_name {true,false}]
                                      [--model_checkpoint.every_n_train_steps EVERY_N_TRAIN_STEPS]
                                      [--model_checkpoint.train_time_interval TRAIN_TIME_INTERVAL]
                                      [--model_checkpoint.every_n_epochs EVERY_N_EPOCHS]
                                      [--model_checkpoint.save_on_train_epoch_end {true,false,null}]
                                      [--model_checkpoint.enable_version_counter {true,false}] [--progress_bar CONFIG]
                                      [--progress_bar.refresh_rate REFRESH_RATE] [--progress_bar.leave {true,false}]
                                      [--progress_bar.theme CONFIG] [--progress_bar.theme.description.help CLASS_PATH_OR_NAME]
                                      [--progress_bar.theme.description DESCRIPTION]
                                      [--progress_bar.theme.progress_bar.help CLASS_PATH_OR_NAME]
                                      [--progress_bar.theme.progress_bar PROGRESS_BAR]
                                      [--progress_bar.theme.progress_bar_finished.help CLASS_PATH_OR_NAME]
                                      [--progress_bar.theme.progress_bar_finished PROGRESS_BAR_FINISHED]
                                      [--progress_bar.theme.progress_bar_pulse.help CLASS_PATH_OR_NAME]
                                      [--progress_bar.theme.progress_bar_pulse PROGRESS_BAR_PULSE]
                                      [--progress_bar.theme.batch_progress.help CLASS_PATH_OR_NAME]
                                      [--progress_bar.theme.batch_progress BATCH_PROGRESS]
                                      [--progress_bar.theme.time.help CLASS_PATH_OR_NAME] [--progress_bar.theme.time TIME]
                                      [--progress_bar.theme.processing_speed.help CLASS_PATH_OR_NAME]
                                      [--progress_bar.theme.processing_speed PROCESSING_SPEED]
                                      [--progress_bar.theme.metrics.help CLASS_PATH_OR_NAME]
                                      [--progress_bar.theme.metrics METRICS]
                                      [--progress_bar.theme.metrics_text_delimiter METRICS_TEXT_DELIMITER]
                                      [--progress_bar.theme.metrics_format METRICS_FORMAT]
                                      [--progress_bar.console_kwargs CONSOLE_KWARGS] [--learning_rate_monitor CONFIG]
                                      [--learning_rate_monitor.logging_interval LOGGING_INTERVAL]
                                      [--learning_rate_monitor.log_momentum {true,false}]
                                      [--learning_rate_monitor.log_weight_decay {true,false}] [--model_summary CONFIG]
                                      [--model_summary.max_depth MAX_DEPTH] [--optimizer.help CLASS_PATH_OR_NAME]
                                      [--optimizer CONFIG | CLASS_PATH_OR_NAME | .INIT_ARG_NAME VALUE]
                                      [--lr_scheduler.help CLASS_PATH_OR_NAME]
                                      [--lr_scheduler CONFIG | CLASS_PATH_OR_NAME | .INIT_ARG_NAME VALUE] [--ckpt_path CKPT_PATH]
error: Parser key "model.arch":
  unsupported operand type(s) for |: '_UnionGenericAlias' and 'type'

Also, if I switch the config file to SpatialNet.yaml, training runs normally.

quancs commented 3 months ago

Then I'm not sure; it runs fine on my side. You could try a different Python version (mine is Python 3.10.13), or create the pytorch-lightning Trainer, LightningModule and DataModule manually to launch training and testing.
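
The error text matches what Python 3.9 raises when a PEP 604 `X | Y` union mixes a typing alias with a plain type: there, `Optional[int]` is a `typing._UnionGenericAlias`, which does not implement `|` (support arrived in Python 3.10). That would explain why the Python version matters here. A minimal sketch, assuming the parser resolves string annotations roughly the way `typing.get_type_hints` does; `hypothetical_init` is an illustrative signature, not code taken from OnlineSpatialNet.py.

```python
import sys
import typing
from typing import Optional


def hypothetical_init(dim_input: "Optional[int] | str" = 4) -> None:
    """Hypothetical signature; the quoted annotation is evaluated lazily,
    just like annotations under `from __future__ import annotations`."""


def annotations_resolvable(func) -> bool:
    """Try to evaluate func's string annotations as a CLI parser might."""
    try:
        typing.get_type_hints(func)
    except TypeError as err:
        # On Python 3.9: unsupported operand type(s) for |: '_UnionGenericAlias' and 'type'
        print(f"TypeError: {err}")
        return False
    return True


if __name__ == "__main__":
    print(sys.version.split()[0], annotations_resolvable(hypothetical_init))
```

If this prints `False` on the failing interpreter, upgrading to Python 3.10 or later is the straightforward fix.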

rookie0607 commented 3 months ago

OK, I'll give it a try. Thanks a lot!

rookie0607 commented 3 months ago

Upgrading to Python 3.10.13 works, haha. By the way, this warning doesn't affect training, does it?

[2024-03-24 17:23:41,269] [6/31] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting the user-defined autograd.Function, we were unable to trace function `trampoline_autograd_fwd` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[2024-03-24 17:23:41,269] [6/31] torch._dynamo.variables.higher_order_ops: [ERROR] inline in skipfiles: MambaInnerFn.forward  | decorate_fwd /home/miniconda3/envs/nbss/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py
quancs commented 3 months ago

Not sure; I haven't seen this warning on my side.