Thank you for your attention to our work. To run inference on a real recorded audio, you can pass run=False to NBSSCLI:
https://github.com/Audio-WestlakeU/NBSS/blob/2e0c353df436a903335e696d43e032b4d075cedd/NBSSCLI.py#L145-L152
Then you can obtain the model and run inference on any audio by calling the model's predict_step:
model = cli.model
batch_size, num_chn, length = 1, 2, 64000
audio = torch.rand((batch_size, num_chn, length))
preds = model.predict_step(audio)
The corresponding command is similar to the one for fit or test, but without the fit or test subcommand:
python NBSSCLI.py --config=logs/NBSS/version_x/config.yaml \
--ckpt_path=logs/NBSS/version_x/checkpoints/epochY_neg_si_sdrZ.ckpt \
--trainer.devices=0,
Thanks for your reply. When I directly use the corresponding command, I get this error:
NBSSCLI.py: error: Unrecognized arguments: --ckpt_path=logs/NBSS/version_8/checkpoints/epoch31_neg_si_sdr-15.4105.ckpt --trainer.devices=0,
Do I need to do the "give run=False to NBSSCLI" step before running the command? I could not find how to give run=False.
Yes. You need to change the code a little first. In NBSS/NBSSCLI.py:
cli = NBSSCLI(
NBSS,
pl.LightningDataModule,
save_config_callback=SaveConfigCallback,
save_config_kwargs={'overwrite': True},
subclass_mode_data=True,
run=False, # add this line
)
# add following lines
model = cli.model
batch_size, num_chn, length = 1, 2, 64000
audio = torch.rand((batch_size, num_chn, length))
preds = model.predict_step(audio)
You can also 1) create an instance of the NBSS module using the arguments in config.yaml, 2) manually load the weights as for a normal torch.nn.Module, and 3) call the predict_step of the NBSS module.
thanks, I will try, and, btw, your profile photo is handsome.
your profile photo is handsome.
Thank you ^_^
Really sorry to disturb you again at night. Here is what I did. First, I changed NBSSCLI.py as follows:
cli = NBSSCLI(
    NBSS,
    pl.LightningDataModule,
    save_config_callback=SaveConfigCallback,
    save_config_kwargs={'overwrite': True},
    subclass_mode_data=True,
    run=False,  # add this line
)
model = cli.model
batch_size, num_chn, length = 1, 2, 64000
audio = torch.rand((batch_size, num_chn, length))
preds = model.predict_step(audio)
Then I ran the command:
python NBSSCLI.py --config=logs/NBSS/version_8/config.yaml --ckpt_path=logs/NBSS/version_8/checkpoints/epoch32_neg_si_sdr-15.3829.ckpt --trainer.devices=0,
The error is:
NBSSCLI.py: error: Unrecognized arguments: --ckpt_path=logs/NBSS/version_8/checkpoints/epoch32_neg_si_sdr-15.3829.ckpt
I tried many times and made sure the ckpt file is correct. When I remove ckpt_path from the command, I get a different error:
NBSSCLI.py: error: 'Configuration check failed :: No action for destination key "ckpt_path" to check its value
So ckpt_path seems to be required, but when I add it to the command, I get the first error.
Sorry, that's my mistake: the given command does not seem to work in the new version. But you can still run inference this way: 1) create an instance of the NBSS module using the arguments in config.yaml, 2) manually load the weights as for a normal torch.nn.Module, and 3) call the predict_step of the NBSS module.
many thanks, I will try that, ^_^
Wait a few minutes; I will write a short example for this.
Oh, I will wait at my seat, even until tomorrow. It's really an honor to see your examples, many thanks.
I verified that the following code works. Remove any arguments from config.yaml that cause errors.
# CMD: python NBSSCLI.py --config logs/NBSS/version_x/config.yaml
# torch, pl, NBSS, NBSSCLI and SaveConfigCallback are assumed to be imported at the top of NBSSCLI.py (add `import torch` if missing)
if __name__ == '__main__':
    cli = NBSSCLI(
        NBSS,
        pl.LightningDataModule,
        save_config_kwargs={'overwrite': True},
        save_config_callback=SaveConfigCallback,
        subclass_mode_data=True,
        run=False,  # add this line
    )
    model = cli.model

    # load weights from checkpoint
    ckpt = 'logs/NBSS/version_293/checkpoints/last.ckpt'
    data = torch.load(ckpt, map_location='cpu')
    model.on_load_checkpoint(data)
    model.load_state_dict(data['state_dict'])
    model.eval()

    # inference an audio (random 2-channel input here; replace with a real recording)
    batch_size, num_chn, length = 1, 2, 64000
    audio = torch.rand((batch_size, num_chn, length))
    with torch.no_grad():  # no gradients are needed for inference
        preds = model.predict_step(audio)
    print(preds.shape)
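Removing offending keys from config.yaml by hand works, but it can also be scripted. The sketch below is a hypothetical helper, not part of NBSS: `strip_keys` and `write_cleaned_config` are made-up names, and it assumes the problematic entry is a top-level `ckpt_path` key in the saved config.yaml (the key named in the 'No action for destination key "ckpt_path"' error reported earlier) and that the file is plain YAML readable by PyYAML's `safe_load`.

```python
import copy


def strip_keys(config: dict, keys=("ckpt_path",)) -> dict:
    """Return a copy of a parsed config with the given top-level keys removed.

    Assumption: the entries the parser rejects (e.g. ckpt_path) are top-level keys.
    """
    cleaned = copy.deepcopy(config)  # leave the caller's dict untouched
    for key in keys:
        cleaned.pop(key, None)  # silently skip keys that are not present
    return cleaned


def write_cleaned_config(src: str, dst: str, keys=("ckpt_path",)) -> None:
    """Read the YAML config at src, drop the given keys, and write it to dst."""
    import yaml  # PyYAML; only needed for the file round trip

    with open(src, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    with open(dst, "w", encoding="utf-8") as f:
        yaml.safe_dump(strip_keys(config, keys), f, sort_keys=False)
```

For example, `write_cleaned_config('logs/NBSS/version_x/config.yaml', 'logs/NBSS/version_x/config_infer.yaml')` writes a cleaned copy (`config_infer.yaml` is a hypothetical name) that can then be passed with `--config`. Writing a separate file keeps the original training config intact.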
Really, thanks! It's raining outside, don't go home too late.
Another way:
from models.NBSS import NBSS
import torch

model = NBSS(  # find the arguments in config.yaml
    arch=...,  # arch sizes: input_size=2*num_channels; output_size=2*num_spk
    io=...,
    # ... (other arguments from config.yaml)
)

# load weights from checkpoint
ckpt = 'logs/NBSS/version_293/checkpoints/last.ckpt'
data = torch.load(ckpt, map_location='cpu')
model.on_load_checkpoint(data)
model.load_state_dict(data['state_dict'])
model.eval()

# inference an audio
batch_size, num_chn, length = 1, 2, 64000
audio = torch.rand((batch_size, num_chn, length))
with torch.no_grad():  # no gradients are needed for inference
    preds = model.predict_step(audio)
print(preds.shape)
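The comment on `arch` encodes a small rule for its input and output sizes. As a worked example, here is a hypothetical helper (`arch_io_sizes` is not part of NBSS) that just applies that rule; the key names mirror the comment, and reading the factor 2 as the real and imaginary parts of each STFT bin is an assumption, not something stated in the thread.

```python
def arch_io_sizes(num_channels: int, num_spk: int) -> dict:
    """Apply the rule from the comment: input_size = 2*num_channels, output_size = 2*num_spk.

    Assumption: the factor 2 stands for real and imaginary parts of each STFT bin.
    """
    if num_channels < 1 or num_spk < 1:
        raise ValueError("num_channels and num_spk must be positive")
    return {"input_size": 2 * num_channels, "output_size": 2 * num_spk}


# two microphones, two speakers
print(arch_io_sizes(num_channels=2, num_spk=2))  # {'input_size': 4, 'output_size': 4}
```

So a 6-mic, 2-speaker setup would need `input_size=12` and `output_size=4` under the same rule; the values in config.yaml remain the authoritative source.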
Oh, I will try both ways.
BRO, WHO ARE YOU?
I am an engineer working mainly on speech enhancement and echo cancellation. I recently started preparing to work on multi-channel speech separation and found your team's paper through the FT-JNF paper.
Sorry, I thought you were one of my friends playing a joke on me...
When I use the changed code in NBSSCLI.py, I get the following error, which is strange because it works on your side:
usage: NBSSCLI.py [-h] [-c CONFIG] [--print_config[=flags]] [--seed_everything SEED_EVERYTHING] [--trainer CONFIG] [--trainer.accelerator.help CLASS_PATH_OR_NAME] [--trainer.accelerator ACCELERATOR] [--trainer.strategy.help CLASS_PATH_OR_NAME] [--trainer.strategy STRATEGY] [--trainer.devices DEVICES] [--trainer.num_nodes NUM_NODES] [--trainer.precision PRECISION] [--trainer.logger.help CLASS_PATH_OR_NAME] [--trainer.logger LOGGER] [--trainer.callbacks.help CLASS_PATH_OR_NAME] [--trainer.callbacks CALLBACKS] [--trainer.fast_dev_run FAST_DEV_RUN] [--trainer.max_epochs MAX_EPOCHS] [--trainer.min_epochs MIN_EPOCHS] [--trainer.max_steps MAX_STEPS] [--trainer.min_steps MIN_STEPS] [--trainer.max_time MAX_TIME] [--trainer.limit_train_batches LIMIT_TRAIN_BATCHES] [--trainer.limit_val_batches LIMIT_VAL_BATCHES] [--trainer.limit_test_batches LIMIT_TEST_BATCHES] [--trainer.limit_predict_batches LIMIT_PREDICT_BATCHES] [--trainer.overfit_batches OVERFIT_BATCHES] [--trainer.val_check_interval VAL_CHECK_INTERVAL] [--trainer.check_val_every_n_epoch CHECK_VAL_EVERY_N_EPOCH] [--trainer.num_sanity_val_steps NUM_SANITY_VAL_STEPS] [--trainer.log_every_n_steps LOG_EVERY_N_STEPS] [--trainer.enable_checkpointing {true,false,null}] [--trainer.enable_progress_bar {true,false,null}] [--trainer.enable_model_summary {true,false,null}] [--trainer.accumulate_grad_batches ACCUMULATE_GRAD_BATCHES] [--trainer.gradient_clip_val GRADIENT_CLIP_VAL] [--trainer.gradient_clip_algorithm GRADIENT_CLIP_ALGORITHM] [--trainer.deterministic DETERMINISTIC] [--trainer.benchmark {true,false,null}] [--trainer.inference_mode {true,false}] [--trainer.use_distributed_sampler {true,false}] [--trainer.profiler.help CLASS_PATH_OR_NAME] [--trainer.profiler PROFILER] [--trainer.detect_anomaly {true,false}] [--trainer.barebones {true,false}] [--trainer.plugins.help CLASS_PATH_OR_NAME] 
[--trainer.plugins PLUGINS] [--trainer.sync_batchnorm {true,false}] [--trainer.reload_dataloaders_every_n_epochs RELOAD_DATALOADERS_EVERY_N_EPOCHS] [--trainer.default_root_dir DEFAULT_ROOT_DIR] [--model CONFIG] [--model.arch.help CLASS_PATH_OR_NAME] --model.arch ARCH [--model.io.help CLASS_PATH_OR_NAME] --model.io IO [--model.speaker_num SPEAKER_NUM] [--model.ref_channel REF_CHANNEL] [--model.channels CHANNELS] [--model.learning_rate LEARNING_RATE] [--model.optimizer OPTIMIZER] [--model.optimizer_kwargs OPTIMIZER_KWARGS] [--model.lr_scheduler LR_SCHEDULER] [--model.lr_scheduler_kwargs LR_SCHEDULER_KWARGS] [--model.exp_name EXP_NAME] [--model.metrics METRICS] [--data.help CLASS_PATH_OR_NAME] --data CONFIG | CLASS_PATH_OR_NAME | .INIT_ARG_NAME VALUE [--progress_bar CONFIG] [--progress_bar.refresh_rate REFRESH_RATE] [--progress_bar.leave {true,false}] [--progress_bar.theme CONFIG] [--progress_bar.theme.description.help CLASS_PATH_OR_NAME] [--progress_bar.theme.description DESCRIPTION] [--progress_bar.theme.progress_bar.help CLASS_PATH_OR_NAME] [--progress_bar.theme.progress_bar PROGRESS_BAR] [--progress_bar.theme.progress_bar_finished.help CLASS_PATH_OR_NAME] [--progress_bar.theme.progress_bar_finished PROGRESS_BAR_FINISHED] [--progress_bar.theme.progress_bar_pulse.help CLASS_PATH_OR_NAME] [--progress_bar.theme.progress_bar_pulse PROGRESS_BAR_PULSE] [--progress_bar.theme.batch_progress.help CLASS_PATH_OR_NAME] [--progress_bar.theme.batch_progress BATCH_PROGRESS] [--progress_bar.theme.time.help CLASS_PATH_OR_NAME] [--progress_bar.theme.time TIME] [--progress_bar.theme.processing_speed.help CLASS_PATH_OR_NAME] [--progress_bar.theme.processing_speed PROCESSING_SPEED] [--progress_bar.theme.metrics.help CLASS_PATH_OR_NAME] [--progress_bar.theme.metrics METRICS] [--progress_bar.console_kwargs CONSOLE_KWARGS] [--early_stopping CONFIG] [--early_stopping.monitor MONITOR] [--early_stopping.min_delta MIN_DELTA] [--early_stopping.patience PATIENCE] [--early_stopping.verbose 
{true,false}] [--early_stopping.mode MODE] [--early_stopping.strict {true,false}] [--early_stopping.check_finite {true,false}] [--early_stopping.stopping_threshold STOPPING_THRESHOLD] [--early_stopping.divergence_threshold DIVERGENCE_THRESHOLD] [--early_stopping.check_on_train_epoch_end {true,false,null}] [--early_stopping.log_rank_zero_only {true,false}] [--model_checkpoint CONFIG] [--model_checkpoint.dirpath DIRPATH] [--model_checkpoint.filename FILENAME] [--model_checkpoint.monitor MONITOR] [--model_checkpoint.verbose {true,false}] [--model_checkpoint.save_last {true,false,null}] [--model_checkpoint.save_top_k SAVE_TOP_K] [--model_checkpoint.save_weights_only {true,false}] [--model_checkpoint.mode MODE] [--model_checkpoint.auto_insert_metric_name {true,false}] [--model_checkpoint.every_n_train_steps EVERY_N_TRAIN_STEPS] [--model_checkpoint.train_time_interval TRAIN_TIME_INTERVAL] [--model_checkpoint.every_n_epochs EVERY_N_EPOCHS] [--model_checkpoint.save_on_train_epoch_end {true,false,null}] [--learning_rate_monitor CONFIG] [--learning_rate_monitor.logging_interval LOGGING_INTERVAL] [--learning_rate_monitor.log_momentum {true,false}] [--optimizer.help CLASS_PATH_OR_NAME] [--optimizer CONFIG | CLASS_PATH_OR_NAME | .INIT_ARG_NAME VALUE] [--lr_scheduler.help CLASS_PATH_OR_NAME] [--lr_scheduler CONFIG | CLASS_PATH_OR_NAME | .INIT_ARG_NAME VALUE]
NBSSCLI.py: error: Parser key "model.io": Type <class 'models.io.narrow_band.nbio.NBIO'> expects: a class path (str); or a dict with a class_path entry; or a dict with init_args (if class path given previously). Got "None".
@wendongj Please try this. I think it is the simplest way.
Should the NBSS model be defined like this?
import yaml

def read_para(file):
    with open(file, 'r', encoding='utf-8') as f:
        return yaml.load(f.read(), Loader=yaml.FullLoader)

config_data = read_para('logs/NBSS/version_8/config.yaml')
model = NBSS(  # find the arguments in config.yaml
    arch=config_data['model']['arch'],  # the input and output size for arch: input_size=2*num_channels; output_size=2*num_spk
    io=config_data['model']['io'],
)
You should instantiate io and arch from their classes and arguments, e.g. arch=models.arch.NBCv2.NBCv2(...). Their class paths and arguments can be found in config.yaml.
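The class paths and arguments in config.yaml are stored as dicts with `class_path` and `init_args` entries (the layout the parser error above asks for). A minimal sketch of turning such dicts into objects with `importlib` follows; `instantiate` is a hypothetical helper, not a jsonargparse or NBSS API, and it assumes every nested object in the config uses that same two-key layout.

```python
import importlib
from typing import Any


def instantiate(spec: Any) -> Any:
    """Recursively build objects from {'class_path': ..., 'init_args': {...}} dicts.

    Assumption: every object in the config uses this class_path/init_args layout.
    Dicts without 'class_path' and lists are returned with their contents resolved.
    """
    if isinstance(spec, dict) and "class_path" in spec:
        module_name, _, class_name = spec["class_path"].rpartition(".")
        cls = getattr(importlib.import_module(module_name), class_name)
        # build nested objects (e.g. a sub-module spec inside init_args) first
        kwargs = {k: instantiate(v) for k, v in (spec.get("init_args") or {}).items()}
        return cls(**kwargs)
    if isinstance(spec, dict):
        return {k: instantiate(v) for k, v in spec.items()}
    if isinstance(spec, list):
        return [instantiate(v) for v in spec]
    return spec  # plain values (numbers, strings, None) pass through unchanged
```

With the `config_data` from the question, `NBSS(**instantiate(config_data['model']))` might then build the model, but only if the `model` section contains exactly NBSS's constructor arguments; check the keys against config.yaml first.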
Sorry, I was in a meeting this morning, hence the late reply. I can run it successfully now. Thanks a lot for your patience.
Great. It’s my pleasure.
Hi, dear author, I was glad to see the state-of-the-art results on the multi-channel speech separation task, and thank you for open-sourcing the code. Thanks to the detailed introduction in the README, I can now train the model with a two-mic configuration. I just want to ask a simple question: if I want to run inference on a real recorded 2-mic signal whose length is not 4 seconds, what should I do? Sorry to disturb you...
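For a real recording, the main step is getting it into the (batch_size, num_chn, length) layout used with predict_step in the examples above. A sketch, assuming the file is read with the soundfile package (whose `read` returns a (frames, channels) array for multichannel files) and that the recording has the same sample rate as the training data; `to_model_input` is a hypothetical helper. Whether predict_step handles lengths other than the 64000-sample example directly is not confirmed in this thread.

```python
import numpy as np


def to_model_input(samples: np.ndarray, num_channels: int = 2) -> np.ndarray:
    """Convert a (frames, channels) recording into a (1, channels, frames) float32 array.

    Assumption: the recording already has the training sample rate (not checked here).
    """
    if samples.ndim != 2 or samples.shape[1] != num_channels:
        raise ValueError(f"expected shape (frames, {num_channels}), got {samples.shape}")
    batch = np.ascontiguousarray(samples.T, dtype=np.float32)  # (channels, frames)
    return batch[np.newaxis, ...]  # add the batch dimension -> (1, channels, frames)
```

Usage would look like `samples, sr = soundfile.read('recording_2mic.wav')` (a hypothetical file name), then `audio = torch.from_numpy(to_model_input(samples))` and `preds = model.predict_step(audio)`.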