Open rookie0607 opened 4 months ago
I haven't tried a larger parameter count; there should still be room for performance improvement.
@quancs Hi, I can't reproduce the WHAMR! results reported in your paper. Could you help me figure out what went wrong? The training and evaluation data were simulated with the scripts from https://my-bucket-a8b4b49c25c811ee9a7e8bba05fa24c7.s3.amazonaws.com/whamr_scripts.tar.gz, using the min, 8 kHz version.

Training command:

python SharedTrainer.py fit \
    --config=configs/SpatialNet.yaml \
    --config=configs/datasets/whamr.yaml \
    --model.channels=[0,1] \
    --model.arch.dim_input=4 \
    --model.arch.dim_output=4 \
    --model.arch.num_freqs=129 \
    --trainer.precision=bf16-mixed \
    --model.compile=True \
    --data.batch_size=[8,8] \
    --trainer.devices=0,1,2,3, \
    --trainer.max_epochs=250
SpatialNet.yaml:

seed_everything: 2
trainer:
  gradient_clip_val: 5
  gradient_clip_algorithm: norm
  devices: null
  accelerator: gpu
  strategy: auto
  sync_batchnorm: false
  precision: 32
model:
  arch:
    class_path: models.arch.SpatialNet.SpatialNet
    init_args:
      # dim_output: 4
      num_layers: 8 # 12 for large
      encoder_kernel_size: 5
      dim_hidden: 96 # 192 for large
      dim_ffn: 192 # 384 for large
      num_heads: 4
      dropout: [0, 0, 0]
      kernel_size: [5, 3]
      conv_groups: [8, 8]
      norms: ["LN", "LN", "GN", "LN", "LN", "LN"]
      dim_squeeze: 8 # 16 for large
      num_freqs: 129
      full_share: 0
  channels: [0, 1, 2, 3, 4, 5]
  ref_channel: 0
  stft:
    class_path: models.io.stft.STFT
    init_args:
      n_fft: 256
      n_hop: 128
  loss:
    class_path: models.io.loss.Loss
    init_args:
      loss_func: models.io.loss.neg_si_sdr
      pit: True
  norm:
    class_path: models.io.norm.Norm
    init_args:
      mode: frequency
  optimizer: [Adam, { lr: 0.002 }]
  lr_scheduler: [ExponentialLR, { gamma: 0.99 }]
  exp_name: exp
  metrics: [SDR, SI_SDR, NB_PESQ, WB_PESQ, eSTOI]
  val_metric: loss
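For context on the loss entry: `models.io.loss.neg_si_sdr` with `pit: True` minimizes the negative SI-SDR under the best speaker permutation. Below is a minimal NumPy sketch of that pairing using the common SI-SDR definition; it is illustrative only, not the repository's actual implementation, and the function names `si_sdr` / `pit_neg_si_sdr` are my own.

```python
import itertools
import numpy as np

def si_sdr(est, ref, eps=1e-8):
    """Scale-invariant SDR in dB (common definition, not the repo's exact code)."""
    ref = ref - ref.mean()
    est = est - est.mean()
    s_target = (est @ ref) / (ref @ ref + eps) * ref  # project estimate onto reference
    e_noise = est - s_target
    return 10 * np.log10((s_target @ s_target) / (e_noise @ e_noise + eps) + eps)

def pit_neg_si_sdr(ests, refs):
    """Negative SI-SDR with permutation-invariant training (PIT) over speakers."""
    n = len(refs)
    best = -np.inf
    for perm in itertools.permutations(range(n)):
        score = np.mean([si_sdr(ests[p], refs[i]) for i, p in enumerate(perm)])
        best = max(best, score)
    return -best  # loss = negative of the best-permutation mean SI-SDR

rng = np.random.default_rng(0)
refs = [rng.standard_normal(8000) for _ in range(2)]
# Estimates are the references in swapped order plus noise; PIT finds the match.
ests = [refs[1] + 0.1 * rng.standard_normal(8000),
        refs[0] + 0.1 * rng.standard_normal(8000)]
print(pit_neg_si_sdr(ests, refs))  # strongly negative, i.e. high SI-SDR
```

Because the metric is scale-invariant, rescaling an estimate leaves its SI-SDR unchanged, which is why the table below reports SI-SDR alongside SDR.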
whamr.yaml:

data:
  class_path: data_loaders.whamr.WHAMRDataModule
  init_args:
    whamr_dir: /home/data/en_train_data/wham/whamr_2ch
    version: min
    target: anechoic
    sample_rate: 8000
    audio_time_len: [4.0, 4.0, null]
    batch_size: [2, 1]
Test results after training for 225 epochs:
| Method | SISDR (dB) | SDR (dB) | NB-PESQ | eSTOI |
|---|---|---|---|---|
| unproc. | -6.1 | -3.5 | 1.41 | 0.317 |
| SpatialNet-small | 11.8 | 13.1 | 2.93 | 0.826 |
| SpatialNet-large | 14.1 | 15.0 | 3.16 | 0.870 |
| Ours (small) | 7.93 | 10.62 | 2.61 | 0.648 |
I compared my previous experiment's configuration with yours, and there are two main differences:
@quancs How well does SpatialNet separate overlapped speech with 3 or more speakers?
I haven't tried separation with 3 simultaneous speakers.
@quancs Could SpatialNet be made streaming-capable? If so, how would one go about it? If not, could you recommend a few CSS models as powerful as SpatialNet? Thank you!
@rookie0607 A streaming version will be open-sourced soon. Here is our paper on the streaming implementation: Multichannel Long-Term Streaming Neural Speech Enhancement for Static and Moving Speakers
Awesome! Really looking forward to it!
@rookie0607 The streaming version is now open-sourced in models/arch/OnlineSpatialNet.py; the corresponding config is configs/onlineSpatialNet.yaml
Nice! I pulled the latest code, but ran into some problems while training OnlineSpatialNet. Could you take a look?
python SharedTrainer.py fit --config=configs/onlineSpatialNet.yaml \
--config=configs/datasets/whamr.yaml \
--model.channels=[0,1] \
--model.arch.dim_input=4 \
--model.arch.dim_output=4 \
--model.arch.num_freqs=129 \
--trainer.precision=16-mixed \
--model.compile=True \
--data.batch_size=[2,2] \
--trainer.devices=0 \
--trainer.max_epochs=100
Error log:
usage: SharedTrainer.py [options] fit [-h] [-c CONFIG] [--print_config[=flags]] [--seed_everything SEED_EVERYTHING]....
SharedTrainer.py [options] fit: error: Parser key "model.arch": unsupported operand type(s) for |: '_UnionGenericAlias' and 'type'
Python environment:
Package Version
------------------------ ------------
absl-py 1.4.0
addict 2.4.0
aiohttp 3.8.5
aiosignal 1.3.1
aliyun-python-sdk-core 2.14.0
aliyun-python-sdk-kms 2.16.2
antlr4-python3-runtime 4.9.3
archspec 0.2.1
async-timeout 4.0.3
attrs 23.1.0
audioread 3.0.0
boltons 23.0.0
Brotli 1.0.9
cachetools 5.3.1
causal-conv1d 1.2.0.post2
certifi 2024.2.2
cffi 1.16.0
cfgv 3.4.0
charset-normalizer 2.0.4
click 8.1.7
cmake 3.27.2
conda 24.1.1
conda-libmamba-solver 24.1.0
conda-package-handling 2.2.0
conda_package_streaming 0.9.0
crcmod 1.7
cryptography 42.0.5
datasets 2.17.1
decorator 5.1.1
dill 0.3.7
distlib 0.3.7
distro 1.8.0
docstring_parser 0.16
editdistance 0.8.1
einops 0.7.0
fairseq2 0.1.0
fairseq2n 0.1.0
filelock 3.12.2
frozenlist 1.4.0
fsspec 2023.6.0
funasr 1.0.10
future 1.0.0
gast 0.5.4
google-auth 2.22.0
google-auth-oauthlib 1.0.0
grpcio 1.57.0
huggingface-hub 0.20.3
hydra-core 1.3.2
identify 2.5.27
idna 3.4
importlib-metadata 6.8.0
importlib_resources 6.3.1
jaconv 0.3.4
jamo 0.4.1
jieba 0.42.1
Jinja2 3.1.2
jiwer 3.0.2
jmespath 0.10.0
joblib 1.3.2
jsonargparse 4.17.0
jsonpatch 1.32
jsonpointer 2.1
kaldiio 2.18.0
lazy_loader 0.3
libmambapy 1.5.6
librosa 0.10.1
lightning-utilities 0.11.0
lit 16.0.6
llvmlite 0.40.1
lxml 4.9.3
mamba-ssm 1.2.0.post1
Markdown 3.4.4
markdown-it-py 3.0.0
MarkupSafe 2.1.3
mdurl 0.1.2
menuinst 2.0.2
mir-eval 0.7
modelscope 1.12.0
mpmath 1.3.0
msgpack 1.0.5
multidict 6.0.4
multiprocess 0.70.15
mypy-extensions 1.0.0
networkx 3.1
ninja 1.11.1.1
nodeenv 1.8.0
numba 0.57.1
numpy 1.24.4
nvidia-cublas-cu11 11.10.3.66
nvidia-cuda-cupti-cu11 11.7.101
nvidia-cuda-nvrtc-cu11 11.7.99
nvidia-cuda-runtime-cu11 11.7.99
nvidia-cudnn-cu11 8.5.0.96
nvidia-cufft-cu11 10.9.0.58
nvidia-curand-cu11 10.2.10.91
nvidia-cusolver-cu11 11.4.0.1
nvidia-cusparse-cu11 11.7.4.91
nvidia-nccl-cu11 2.14.3
nvidia-nvtx-cu11 11.7.91
oauthlib 3.2.2
omegaconf 2.3.0
oss2 2.18.4
overrides 7.4.0
packaging 23.1
pandas 2.0.3
pesq 0.0.4
pillow 10.2.0
pip 23.3.1
platformdirs 3.10.0
pluggy 1.0.0
pooch 1.7.0
portalocker 2.7.0
pre-commit 3.3.3
protobuf 4.24.1
psutil 5.9.5
pyarrow 13.0.0
pyarrow-hotfix 0.6
pyasn1 0.5.0
pyasn1-modules 0.3.0
pycosat 0.6.6
pycparser 2.21
pycryptodome 3.20.0
pydub 0.25.1
Pygments 2.17.2
pynndescent 0.5.11
pyre-extensions 0.0.30
PySocks 1.7.1
pystoi 0.4.1
python-dateutil 2.8.2
pytorch-lightning 2.0.0
pytorch-wpe 0.0.1
pytz 2023.3
PyYAML 6.0.1
rapidfuzz 2.13.7
regex 2023.8.8
requests 2.31.0
requests-oauthlib 1.3.1
rich 13.7.1
rsa 4.9
ruamel.yaml 0.17.21
ruamel.yaml.clib 0.2.8
sacrebleu 2.3.1
safetensors 0.4.2
scikit-learn 1.3.0
scipy 1.11.2
seamless-communication 1.0.0
sentencepiece 0.2.0
setuptools 68.2.2
simplejson 3.19.2
six 1.16.0
sortedcontainers 2.4.0
soundfile 0.12.1
soxr 0.3.6
sympy 1.12
tabulate 0.9.0
tbb 2021.8.0
tensorboard 2.14.0
tensorboard-data-server 0.7.1
threadpoolctl 3.2.0
tokenizers 0.15.2
tomli 2.0.1
torch 2.1.1+cu118
torch-complex 0.4.3
torchaudio 2.1.1+cu118
torcheval 0.0.6
torchmetrics 1.3.2
torchtnt 0.2.0
torchvision 0.16.1+cu118
tqdm 4.65.0
transformers 4.38.2
triton 2.1.0
typeshed_client 2.5.1
typing_extensions 4.7.1
typing-inspect 0.9.0
tzdata 2023.3
umap-learn 0.5.5
urllib3 1.26.18
validators 0.23.2
virtualenv 20.24.3
Werkzeug 2.3.7
wheel 0.41.2
xxhash 3.3.0
yapf 0.40.2
yarl 1.9.2
zipp 3.16.2
zstandard 0.19.0
@quancs
Sorry, the code was incomplete. I just pushed a fix. I have verified on my machine that the command below trains successfully:
python SharedTrainer.py fit --config=configs/onlineSpatialNet.yaml \
--config=configs/datasets/whamr.yaml \
--model.channels=[0,1] \
--model.arch.dim_input=4 \
--model.arch.dim_output=4 \
--model.arch.num_freqs=129 \
--trainer.precision=16-mixed \
--model.compile=true \ # set to false for mamba
--data.batch_size=[2,2] \
--trainer.devices=0, \
--trainer.max_epochs=100 \
--model.stft.n_fft=256 \
--model.stft.n_hop=128
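One note on these flags: `--model.arch.num_freqs=129` must match the STFT settings, because a one-sided spectrum of an `n_fft`-point real frame has `n_fft // 2 + 1` frequency bins. A quick NumPy check of that relation (independent of the repository's STFT wrapper):

```python
import numpy as np

n_fft, n_hop = 256, 128           # the values passed as --model.stft.n_fft / n_hop
frame = np.zeros(n_fft)           # one analysis frame of a real-valued signal
spec = np.fft.rfft(frame)         # one-sided FFT, as an STFT computes per frame
print(len(spec), n_fft // 2 + 1)  # both are 129, matching --model.arch.num_freqs=129
```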
I pulled the latest code and ran exactly the command you gave, but still got the same error.
@rookie0607 I just double-checked: the command and code both work fine on my machine. You could check whether some of your package versions need upgrading.
Is that single line the only error message you got? It looks like a problem with the jsonargparse package.
@quancs It doesn't seem to be an environment problem. I aligned all package versions with the a.txt you provided and still get the error:
usage: SharedTrainer.py [options] fit [-h] [-c CONFIG] [--print_config[=flags]] [--seed_everything SEED_EVERYTHING]
[--trainer CONFIG] [--trainer.accelerator.help CLASS_PATH_OR_NAME]
[--trainer.accelerator ACCELERATOR] [--trainer.strategy.help CLASS_PATH_OR_NAME]
[--trainer.strategy STRATEGY] [--trainer.devices DEVICES] [--trainer.num_nodes NUM_NODES]
[--trainer.precision PRECISION] [--trainer.logger.help CLASS_PATH_OR_NAME]
[--trainer.logger LOGGER] [--trainer.callbacks.help CLASS_PATH_OR_NAME]
[--trainer.callbacks CALLBACKS] [--trainer.fast_dev_run FAST_DEV_RUN]
[--trainer.max_epochs MAX_EPOCHS] [--trainer.min_epochs MIN_EPOCHS]
[--trainer.max_steps MAX_STEPS] [--trainer.min_steps MIN_STEPS]
[--trainer.max_time MAX_TIME] [--trainer.limit_train_batches LIMIT_TRAIN_BATCHES]
[--trainer.limit_val_batches LIMIT_VAL_BATCHES]
[--trainer.limit_test_batches LIMIT_TEST_BATCHES]
[--trainer.limit_predict_batches LIMIT_PREDICT_BATCHES]
[--trainer.overfit_batches OVERFIT_BATCHES]
[--trainer.val_check_interval VAL_CHECK_INTERVAL]
[--trainer.check_val_every_n_epoch CHECK_VAL_EVERY_N_EPOCH]
[--trainer.num_sanity_val_steps NUM_SANITY_VAL_STEPS]
[--trainer.log_every_n_steps LOG_EVERY_N_STEPS]
[--trainer.enable_checkpointing {true,false,null}]
[--trainer.enable_progress_bar {true,false,null}]
[--trainer.enable_model_summary {true,false,null}]
[--trainer.accumulate_grad_batches ACCUMULATE_GRAD_BATCHES]
[--trainer.gradient_clip_val GRADIENT_CLIP_VAL]
[--trainer.gradient_clip_algorithm GRADIENT_CLIP_ALGORITHM]
[--trainer.deterministic DETERMINISTIC] [--trainer.benchmark {true,false,null}]
[--trainer.inference_mode {true,false}] [--trainer.use_distributed_sampler {true,false}]
[--trainer.profiler.help CLASS_PATH_OR_NAME] [--trainer.profiler PROFILER]
[--trainer.detect_anomaly {true,false}] [--trainer.barebones {true,false}]
[--trainer.plugins.help CLASS_PATH_OR_NAME] [--trainer.plugins PLUGINS]
[--trainer.sync_batchnorm {true,false}]
[--trainer.reload_dataloaders_every_n_epochs RELOAD_DATALOADERS_EVERY_N_EPOCHS]
[--trainer.default_root_dir DEFAULT_ROOT_DIR] [--model CONFIG]
[--model.arch.help CLASS_PATH_OR_NAME] --model.arch ARCH --model.channels CHANNELS
--model.ref_channel REF_CHANNEL [--model.stft.help CLASS_PATH_OR_NAME]
[--model.stft STFT] [--model.norm.help CLASS_PATH_OR_NAME] [--model.norm NORM]
[--model.loss.help CLASS_PATH_OR_NAME] [--model.loss LOSS] [--model.optimizer [ITEM,...]]
[--model.lr_scheduler LR_SCHEDULER] [--model.metrics METRICS] [--model.mchunk MCHUNK]
[--model.val_metric VAL_METRIC] [--model.write_examples WRITE_EXAMPLES]
[--model.ensemble ENSEMBLE] [--model.compile {true,false}] [--model.exp_name EXP_NAME]
[--model.reset RESET] [--data.help CLASS_PATH_OR_NAME]
--data CONFIG | CLASS_PATH_OR_NAME | .INIT_ARG_NAME VALUE [--early_stopping CONFIG]
[--early_stopping.enable {true,false}] [--early_stopping.monitor MONITOR]
[--early_stopping.min_delta MIN_DELTA] [--early_stopping.patience PATIENCE]
[--early_stopping.verbose {true,false}] [--early_stopping.mode MODE]
[--early_stopping.strict {true,false}] [--early_stopping.check_finite {true,false}]
[--early_stopping.stopping_threshold STOPPING_THRESHOLD]
[--early_stopping.divergence_threshold DIVERGENCE_THRESHOLD]
[--early_stopping.check_on_train_epoch_end {true,false,null}]
[--early_stopping.log_rank_zero_only {true,false}] [--model_checkpoint CONFIG]
[--model_checkpoint.dirpath DIRPATH] [--model_checkpoint.filename FILENAME]
[--model_checkpoint.monitor MONITOR] [--model_checkpoint.verbose {true,false}]
[--model_checkpoint.save_last {true,false,null}] [--model_checkpoint.save_top_k SAVE_TOP_K]
[--model_checkpoint.save_weights_only {true,false}] [--model_checkpoint.mode MODE]
[--model_checkpoint.auto_insert_metric_name {true,false}]
[--model_checkpoint.every_n_train_steps EVERY_N_TRAIN_STEPS]
[--model_checkpoint.train_time_interval TRAIN_TIME_INTERVAL]
[--model_checkpoint.every_n_epochs EVERY_N_EPOCHS]
[--model_checkpoint.save_on_train_epoch_end {true,false,null}]
[--model_checkpoint.enable_version_counter {true,false}] [--progress_bar CONFIG]
[--progress_bar.refresh_rate REFRESH_RATE] [--progress_bar.leave {true,false}]
[--progress_bar.theme CONFIG] [--progress_bar.theme.description.help CLASS_PATH_OR_NAME]
[--progress_bar.theme.description DESCRIPTION]
[--progress_bar.theme.progress_bar.help CLASS_PATH_OR_NAME]
[--progress_bar.theme.progress_bar PROGRESS_BAR]
[--progress_bar.theme.progress_bar_finished.help CLASS_PATH_OR_NAME]
[--progress_bar.theme.progress_bar_finished PROGRESS_BAR_FINISHED]
[--progress_bar.theme.progress_bar_pulse.help CLASS_PATH_OR_NAME]
[--progress_bar.theme.progress_bar_pulse PROGRESS_BAR_PULSE]
[--progress_bar.theme.batch_progress.help CLASS_PATH_OR_NAME]
[--progress_bar.theme.batch_progress BATCH_PROGRESS]
[--progress_bar.theme.time.help CLASS_PATH_OR_NAME] [--progress_bar.theme.time TIME]
[--progress_bar.theme.processing_speed.help CLASS_PATH_OR_NAME]
[--progress_bar.theme.processing_speed PROCESSING_SPEED]
[--progress_bar.theme.metrics.help CLASS_PATH_OR_NAME]
[--progress_bar.theme.metrics METRICS]
[--progress_bar.theme.metrics_text_delimiter METRICS_TEXT_DELIMITER]
[--progress_bar.theme.metrics_format METRICS_FORMAT]
[--progress_bar.console_kwargs CONSOLE_KWARGS] [--learning_rate_monitor CONFIG]
[--learning_rate_monitor.logging_interval LOGGING_INTERVAL]
[--learning_rate_monitor.log_momentum {true,false}]
[--learning_rate_monitor.log_weight_decay {true,false}] [--model_summary CONFIG]
[--model_summary.max_depth MAX_DEPTH] [--optimizer.help CLASS_PATH_OR_NAME]
[--optimizer CONFIG | CLASS_PATH_OR_NAME | .INIT_ARG_NAME VALUE]
[--lr_scheduler.help CLASS_PATH_OR_NAME]
[--lr_scheduler CONFIG | CLASS_PATH_OR_NAME | .INIT_ARG_NAME VALUE] [--ckpt_path CKPT_PATH]
error: Parser key "model.arch":
unsupported operand type(s) for |: '_UnionGenericAlias' and 'type'
Also, if I switch the config to SpatialNet.yaml, it runs fine.
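For the record, that TypeError is what Python 3.9 and earlier raise when a `|` union is built at runtime from a typing alias; `|` support on typing constructs only arrived with PEP 604 in Python 3.10, which is consistent with the Python upgrade that resolves the issue later in this thread. A minimal reproduction, independent of jsonargparse (the classes A/B/C are just placeholders):

```python
import sys
from typing import Union

class A: ...
class B: ...
class C: ...

# "Union[A, B] | C" is the kind of expression a parser builds from annotations
# such as "Optional[X] | Y". typing aliases gained __or__ only in Python 3.10;
# on 3.9 and earlier this raises:
#   TypeError: unsupported operand type(s) for |: '_UnionGenericAlias' and 'type'
if sys.version_info >= (3, 10):
    print(Union[A, B] | C)  # fine on 3.10+
else:
    try:
        Union[A, B] | C
    except TypeError as e:
        print("reproduced:", e)
```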
I'm not sure then; it runs fine on my side. You could try changing your Python version (mine is Python 3.10.13), or directly create the pytorch-lightning Trainer, LightningModule, and DataModule by hand to launch training and testing.
OK, I'll give it a try. Thanks!
After upgrading to Python 3.10.13, it runs! Haha.
By the way, this warning doesn't affect training, does it?
[2024-03-24 17:23:41,269] [6/31] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting the user-defined autograd.Function, we were unable to trace function `trampoline_autograd_fwd` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[2024-03-24 17:23:41,269] [6/31] torch._dynamo.variables.higher_order_ops: [ERROR] inline in skipfiles: MambaInnerFn.forward | decorate_fwd /home/miniconda3/envs/nbss/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py
Not sure; I don't see this warning on my side.
Thanks for open-sourcing this excellent work. I have a question: the paper shows a fairly large performance gain going from SpatialNet-small to SpatialNet-large. Have you tried an even larger SpatialNet? Is SpatialNet-large already the ceiling?