Audio-WestlakeU / NBSS

The official repo of NBC & SpatialNet for multichannel speech separation, denoising, and dereverberation

Using a real recording for inference #15

Closed: wendongj closed this issue 1 year ago

wendongj commented 1 year ago

Hi, dear author. I'm glad to see the state-of-the-art results on the multi-channel speech separation task, and thanks for open-sourcing the code. Thanks to the detailed introduction in the README, I can now train the model with a two-mic config. I just want to ask a simple question: if I want to run inference on a real recorded 2-mic signal whose length is not 4 seconds, how should I do it? Sorry to disturb you.

quancs commented 1 year ago

Thank you for your attention to our work. To run inference on a real-recorded audio, you can ...

wendongj commented 1 year ago

Thanks for your reply. When I directly use "the corresponding command", I get this error: "NBSSCLI.py: error: Unrecognized arguments: --ckpt_path=logs/NBSS/version_8/checkpoints/epoch31_neg_si_sdr-15.4105.ckpt --trainer.devices=0,". Do I need to do the "give run=False to NBSSCLI" step before running the command? I did not find how to give run=False.

quancs commented 1 year ago

Do I need to do the "give run=False to NBSSCLI" step before running the command? I did not find how to give run=False.

Yes, you should change the code a little bit first. In NBSS/NBSSCLI.py:

    cli = NBSSCLI(
        NBSS,
        pl.LightningDataModule,
        save_config_callback=SaveConfigCallback,
        save_config_kwargs={'overwrite': True},
        subclass_mode_data=True,
        run=False, # add this line
    )
    # add following lines
    model = cli.model
    batch_size, num_chn, length = 1, 2, 64000
    audio = torch.rand((batch_size, num_chn, length))
    preds = model.predict_step(audio)
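
For a real recording, just replace the random tensor with the actual signal; predict_step takes a float tensor of shape (batch_size, num_chn, length), and the length is not fixed to 4 seconds. A minimal sketch, assuming the soundfile package is installed and my_recording.wav is a hypothetical 2-channel file recorded at the sample rate used for training:

    import soundfile as sf
    import torch

    # read a real 2-channel recording; sf.read returns an array of shape (samples, channels)
    wav, sr = sf.read('my_recording.wav', dtype='float32')  # hypothetical path; check that sr matches training
    audio = torch.from_numpy(wav).T.unsqueeze(0)  # -> (1, num_chn, length)
    preds = model.predict_step(audio)
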
quancs commented 1 year ago

You can also 1) create an instance of the NBSS module using the arguments in config.yaml, 2) manually load the weights like a normal torch.nn.Module, and 3) call the predict_step of the NBSS module.

wendongj commented 1 year ago

Do I need to do the "give run=False to NBSSCLI" step before running the command? I did not find how to give run=False.

Yes, you should change the code a little bit first. In NBSS/NBSSCLI.py:

    cli = NBSSCLI(
        NBSS,
        pl.LightningDataModule,
        save_config_callback=SaveConfigCallback,
        save_config_kwargs={'overwrite': True},
        subclass_mode_data=True,
        run=False, # add this line
    )
    # add following lines
    model = cli.model
    batch_size, num_chn, length = 1, 2, 64000
    audio = torch.rand((batch_size, num_chn, length))
    preds = model.predict_step(audio)

Thanks, I will try. And, btw, your profile photo is handsome.

quancs commented 1 year ago

your profile photo is handsome.

Thank you ^_^

wendongj commented 1 year ago

your profile photo is handsome.

Thank you ^_^

Really sorry to disturb you again at night. I did the following: first, I changed the NBSSCLI python file like this:

    cli = NBSSCLI(
        NBSS,
        pl.LightningDataModule,
        save_config_callback=SaveConfigCallback,
        save_config_kwargs={'overwrite': True},
        subclass_mode_data=True,
        run=False, # add this line
    )
    # add following lines
    model = cli.model
    batch_size, num_chn, length = 1, 2, 64000
    audio = torch.rand((batch_size, num_chn, length))
    preds = model.predict_step(audio)

Then I run the command:

    python NBSSCLI.py --config=logs/NBSS/version_8/config.yaml --ckpt_path=logs/NBSS/version_8/checkpoints/epoch32_neg_si_sdr-15.3829.ckpt --trainer.devices=0,

The error is: "NBSSCLI.py: error: Unrecognized arguments: --ckpt_path=logs/NBSS/version_8/checkpoints/epoch32_neg_si_sdr-15.3829.ckpt". I tried many times and made sure the ckpt file is right. When I remove the ckpt_path from the command, it gives: "NBSSCLI.py: error: 'Configuration check failed :: No action for destination key "ckpt_path" to check its value'". It seems ckpt_path is really needed, but when I add ckpt_path to the command, the error comes back.

quancs commented 1 year ago

Sorry, it's my mistake; the given command doesn't seem to work in the new version. But, anyway, you can still run inference on the audio by following this:

You can also 1) create an instance of the NBSS module using the arguments in config.yaml, 2) manually load the weights like a normal torch.nn.Module, and 3) call the predict_step of the NBSS module.

wendongj commented 1 year ago

Many thanks, I will try that ^_^

quancs commented 1 year ago

Give me a few minutes; I will write a short example for this.

wendongj commented 1 year ago

Oh, I will wait at my seat until tomorrow. It's really my honor to see your example, and thanks a lot.

quancs commented 1 year ago

I verified that the following code works. Remove any arguments in config.yaml if they cause errors.

# CMD:  python NBSSCLI.py --config logs/NBSS/version_x/config.yaml
if __name__ == '__main__':
    cli = NBSSCLI(
        NBSS,
        pl.LightningDataModule,
        save_config_kwargs={'overwrite': True},
        save_config_callback=SaveConfigCallback,
        subclass_mode_data=True,
        run=False,  # add this line
    )
    model = cli.model

    # load weights from checkpoint
    ckpt = 'logs/NBSS/version_293/checkpoints/last.ckpt'
    data = torch.load(ckpt, map_location='cpu')
    model.on_load_checkpoint(data)
    model.load_state_dict(data['state_dict'])
    model.eval()

    # inference an audio
    batch_size, num_chn, length = 1, 2, 64000
    audio = torch.rand((batch_size, num_chn, length))
    preds = model.predict_step(audio)
    print(preds.shape)
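
If you want to listen to the result, the predictions can be written back to wav files. A small sketch, assuming preds comes out as (batch, num_spk, time) waveforms and that 16 kHz is the sample rate used during training (both assumptions, please verify against your config):

    import soundfile as sf

    sr = 16000  # assumption: use the sample rate from your training config
    est = preds[0].detach().cpu().numpy()  # assumed layout: (num_spk, time)
    for spk, wav in enumerate(est):
        sf.write(f'separated_spk{spk}.wav', wav, sr)  # hypothetical output file names
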
wendongj commented 1 year ago

I verified that the following code works.

# CMD:  python NBSSCLI.py --config logs/NBSS/version_x/config.yaml
if __name__ == '__main__':
    cli = NBSSCLI(
        NBSS,
        pl.LightningDataModule,
        save_config_kwargs={'overwrite': True},
        save_config_callback=SaveConfigCallback,
        subclass_mode_data=True,
        run=False,  # add this line
    )
    model = cli.model

    # load weights from checkpoint
    ckpt = 'logs/NBSS/version_293/checkpoints/last.ckpt'
    data = torch.load(ckpt, map_location='cpu')
    model.on_load_checkpoint(data)
    model.load_state_dict(data['state_dict'])
    model.eval()

    # inference an audio
    batch_size, num_chn, length = 1, 2, 64000
    audio = torch.rand((batch_size, num_chn, length))
    preds = model.predict_step(audio)
    print(preds.shape)

Thanks a lot. It's raining outside, don't go home too late.

quancs commented 1 year ago

Another way:

from models.NBSS import NBSS
import torch

model=NBSS( # find the arguments from config.yaml.
    arch=..., # the input and output size for arch: input_size=2*num_channels; output_size=2*num_spk
    io=...,
    ...
)

# load weights from checkpoint
ckpt = 'logs/NBSS/version_293/checkpoints/last.ckpt'
data = torch.load(ckpt, map_location='cpu')
model.on_load_checkpoint(data)
model.load_state_dict(data['state_dict'])
model.eval()

# inference an audio
batch_size, num_chn, length = 1, 2, 64000
audio = torch.rand((batch_size, num_chn, length))
preds = model.predict_step(audio)
print(preds.shape)

wendongj commented 1 year ago

Another way:

from models.NBSS import NBSS
import torch

model=NBSS( # find the arguments from config.yaml.
    arch=..., # the input and output size for arch: input_size=2*num_channels; output_size=2*num_spk
    io=...,
    ...
)

# load weights from checkpoint
ckpt = 'logs/NBSS/version_293/checkpoints/last.ckpt'
data = torch.load(ckpt, map_location='cpu')
model.on_load_checkpoint(data)
model.load_state_dict(data['state_dict'])
model.eval()

# inference an audio
batch_size, num_chn, length = 1, 2, 64000
audio = torch.rand((batch_size, num_chn, length))
preds = model.predict_step(audio)
print(preds.shape)

Oh, I will try both ways.

wendongj commented 1 year ago

Thanks a lot. It's raining outside, don't go home too late.

BRO, WHO ARE YOU?

I am an engineer mainly working on speech enhancement and echo cancellation, and I'm recently preparing to do multi-channel speech separation. I found your team's paper through the FT-JNF paper.

quancs commented 1 year ago

Sorry, I thought you were one of my friends playing a joke on me...

wendongj commented 1 year ago

I verified that the following code works. Remove any arguments in config.yaml if they cause errors.

# CMD:  python NBSSCLI.py --config logs/NBSS/version_x/config.yaml
if __name__ == '__main__':
    cli = NBSSCLI(
        NBSS,
        pl.LightningDataModule,
        save_config_kwargs={'overwrite': True},
        save_config_callback=SaveConfigCallback,
        subclass_mode_data=True,
        run=False,  # add this line
    )
    model = cli.model

    # load weights from checkpoint
    ckpt = 'logs/NBSS/version_293/checkpoints/last.ckpt'
    data = torch.load(ckpt, map_location='cpu')
    model.on_load_checkpoint(data)
    model.load_state_dict(data['state_dict'])
    model.eval()

    # inference an audio
    batch_size, num_chn, length = 1, 2, 64000
    audio = torch.rand((batch_size, num_chn, length))
    preds = model.predict_step(audio)
    print(preds.shape)

When I use the changed code in NBSSCLI.py, I get the following error, but on your side there are no errors, which is strange:

    usage: NBSSCLI.py [-h] [-c CONFIG] [--print_config[=flags]] [--seed_everything SEED_EVERYTHING] [--trainer CONFIG]
    [--trainer.accelerator.help CLASS_PATH_OR_NAME] [--trainer.accelerator ACCELERATOR] [--trainer.strategy.help CLASS_PATH_OR_NAME] [--trainer.strategy STRATEGY]
    [--trainer.devices DEVICES] [--trainer.num_nodes NUM_NODES] [--trainer.precision PRECISION] [--trainer.logger.help CLASS_PATH_OR_NAME] [--trainer.logger LOGGER]
    [--trainer.callbacks.help CLASS_PATH_OR_NAME] [--trainer.callbacks CALLBACKS] [--trainer.fast_dev_run FAST_DEV_RUN] [--trainer.max_epochs MAX_EPOCHS]
    [--trainer.min_epochs MIN_EPOCHS] [--trainer.max_steps MAX_STEPS] [--trainer.min_steps MIN_STEPS] [--trainer.max_time MAX_TIME]
    [--trainer.limit_train_batches LIMIT_TRAIN_BATCHES] [--trainer.limit_val_batches LIMIT_VAL_BATCHES] [--trainer.limit_test_batches LIMIT_TEST_BATCHES]
    [--trainer.limit_predict_batches LIMIT_PREDICT_BATCHES] [--trainer.overfit_batches OVERFIT_BATCHES] [--trainer.val_check_interval VAL_CHECK_INTERVAL]
    [--trainer.check_val_every_n_epoch CHECK_VAL_EVERY_N_EPOCH] [--trainer.num_sanity_val_steps NUM_SANITY_VAL_STEPS] [--trainer.log_every_n_steps LOG_EVERY_N_STEPS]
    [--trainer.enable_checkpointing {true,false,null}] [--trainer.enable_progress_bar {true,false,null}] [--trainer.enable_model_summary {true,false,null}]
    [--trainer.accumulate_grad_batches ACCUMULATE_GRAD_BATCHES] [--trainer.gradient_clip_val GRADIENT_CLIP_VAL] [--trainer.gradient_clip_algorithm GRADIENT_CLIP_ALGORITHM]
    [--trainer.deterministic DETERMINISTIC] [--trainer.benchmark {true,false,null}] [--trainer.inference_mode {true,false}] [--trainer.use_distributed_sampler {true,false}]
    [--trainer.profiler.help CLASS_PATH_OR_NAME] [--trainer.profiler PROFILER] [--trainer.detect_anomaly {true,false}] [--trainer.barebones {true,false}]
    [--trainer.plugins.help CLASS_PATH_OR_NAME] [--trainer.plugins PLUGINS] [--trainer.sync_batchnorm {true,false}]
    [--trainer.reload_dataloaders_every_n_epochs RELOAD_DATALOADERS_EVERY_N_EPOCHS] [--trainer.default_root_dir DEFAULT_ROOT_DIR]
    [--model CONFIG] [--model.arch.help CLASS_PATH_OR_NAME] --model.arch ARCH [--model.io.help CLASS_PATH_OR_NAME] --model.io IO [--model.speaker_num SPEAKER_NUM]
    [--model.ref_channel REF_CHANNEL] [--model.channels CHANNELS] [--model.learning_rate LEARNING_RATE] [--model.optimizer OPTIMIZER]
    [--model.optimizer_kwargs OPTIMIZER_KWARGS] [--model.lr_scheduler LR_SCHEDULER] [--model.lr_scheduler_kwargs LR_SCHEDULER_KWARGS] [--model.exp_name EXP_NAME]
    [--model.metrics METRICS] [--data.help CLASS_PATH_OR_NAME] --data CONFIG | CLASS_PATH_OR_NAME | .INIT_ARG_NAME VALUE
    [--progress_bar CONFIG] [--progress_bar.refresh_rate REFRESH_RATE] [--progress_bar.leave {true,false}] [--progress_bar.theme CONFIG]
    [--progress_bar.theme.description.help CLASS_PATH_OR_NAME] [--progress_bar.theme.description DESCRIPTION]
    [--progress_bar.theme.progress_bar.help CLASS_PATH_OR_NAME] [--progress_bar.theme.progress_bar PROGRESS_BAR]
    [--progress_bar.theme.progress_bar_finished.help CLASS_PATH_OR_NAME] [--progress_bar.theme.progress_bar_finished PROGRESS_BAR_FINISHED]
    [--progress_bar.theme.progress_bar_pulse.help CLASS_PATH_OR_NAME] [--progress_bar.theme.progress_bar_pulse PROGRESS_BAR_PULSE]
    [--progress_bar.theme.batch_progress.help CLASS_PATH_OR_NAME] [--progress_bar.theme.batch_progress BATCH_PROGRESS]
    [--progress_bar.theme.time.help CLASS_PATH_OR_NAME] [--progress_bar.theme.time TIME]
    [--progress_bar.theme.processing_speed.help CLASS_PATH_OR_NAME] [--progress_bar.theme.processing_speed PROCESSING_SPEED]
    [--progress_bar.theme.metrics.help CLASS_PATH_OR_NAME] [--progress_bar.theme.metrics METRICS] [--progress_bar.console_kwargs CONSOLE_KWARGS]
    [--early_stopping CONFIG] [--early_stopping.monitor MONITOR] [--early_stopping.min_delta MIN_DELTA] [--early_stopping.patience PATIENCE]
    [--early_stopping.verbose {true,false}] [--early_stopping.mode MODE] [--early_stopping.strict {true,false}] [--early_stopping.check_finite {true,false}]
    [--early_stopping.stopping_threshold STOPPING_THRESHOLD] [--early_stopping.divergence_threshold DIVERGENCE_THRESHOLD]
    [--early_stopping.check_on_train_epoch_end {true,false,null}] [--early_stopping.log_rank_zero_only {true,false}]
    [--model_checkpoint CONFIG] [--model_checkpoint.dirpath DIRPATH] [--model_checkpoint.filename FILENAME] [--model_checkpoint.monitor MONITOR]
    [--model_checkpoint.verbose {true,false}] [--model_checkpoint.save_last {true,false,null}] [--model_checkpoint.save_top_k SAVE_TOP_K]
    [--model_checkpoint.save_weights_only {true,false}] [--model_checkpoint.mode MODE] [--model_checkpoint.auto_insert_metric_name {true,false}]
    [--model_checkpoint.every_n_train_steps EVERY_N_TRAIN_STEPS] [--model_checkpoint.train_time_interval TRAIN_TIME_INTERVAL]
    [--model_checkpoint.every_n_epochs EVERY_N_EPOCHS] [--model_checkpoint.save_on_train_epoch_end {true,false,null}]
    [--learning_rate_monitor CONFIG] [--learning_rate_monitor.logging_interval LOGGING_INTERVAL] [--learning_rate_monitor.log_momentum {true,false}]
    [--optimizer.help CLASS_PATH_OR_NAME] [--optimizer CONFIG | CLASS_PATH_OR_NAME | .INIT_ARG_NAME VALUE]
    [--lr_scheduler.help CLASS_PATH_OR_NAME] [--lr_scheduler CONFIG | CLASS_PATH_OR_NAME | .INIT_ARG_NAME VALUE]
    NBSSCLI.py: error: Parser key "model.io": Type <class 'models.io.narrow_band.nbio.NBIO'> expects: a class path (str); or a dict with a class_path entry; or a dict with init_args (if class path given previously). Got "None".

quancs commented 1 year ago

Another way:

from models.NBSS import NBSS
import torch

model=NBSS( # find the arguments from config.yaml.
    arch=..., # the input and output size for arch: input_size=2*num_channels; output_size=2*num_spk
    io=...,
    ...
)

# load weights from checkpoint
ckpt = 'logs/NBSS/version_293/checkpoints/last.ckpt'
data = torch.load(ckpt, map_location='cpu')
model.on_load_checkpoint(data)
model.load_state_dict(data['state_dict'])
model.eval()

# inference an audio
batch_size, num_chn, length = 1, 2, 64000
audio = torch.rand((batch_size, num_chn, length))
preds = model.predict_step(audio)
print(preds.shape)

@wendongj Please try this. I think it is the simplest way.

wendongj commented 1 year ago

Is the NBSS model defined like this?

    def read_para(file):
        f = open(file, 'r', encoding='utf-8')
        data = f.read()
        return yaml.load(data, Loader=yaml.FullLoader)

    config_data = read_para('logs/NBSS/version_8/config.yaml')

    model = NBSS(  # find the arguments from config.yaml
        arch=config_data['model']['arch'],  # the input and output size for arch: input_size=2*num_channels; output_size=2*num_spk
        io=config_data['model']['io'],
    )

quancs commented 1 year ago

You should instantiate io and arch with their classes and arguments, e.g. arch=models.arch.NBCv2.NBCv2(...). Their class paths and arguments can be found in config.yaml.
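
If you prefer not to type the arguments by hand, one possible way is to read config.yaml and build arch and io from their class_path/init_args entries. This is only a sketch, under the assumption that config.yaml stores model.arch and model.io in the usual LightningCLI class_path / init_args form:

    import importlib
    import yaml

    from models.NBSS import NBSS

    def instantiate(spec):
        # build an object from a {'class_path': ..., 'init_args': {...}} dict
        module_name, class_name = spec['class_path'].rsplit('.', 1)
        cls = getattr(importlib.import_module(module_name), class_name)
        return cls(**spec.get('init_args', {}))

    with open('logs/NBSS/version_8/config.yaml', 'r', encoding='utf-8') as f:
        cfg = yaml.safe_load(f)

    model = NBSS(
        arch=instantiate(cfg['model']['arch']),
        io=instantiate(cfg['model']['io']),
    )

If the init_args themselves contain nested class_path entries, they would need the same treatment.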

wendongj commented 1 year ago

    model=NBSS( # find the arguments from config.yaml.
        arch=..., # the input and output size for arch: input_size=2*num_channels; output_size=2*num_spk
        io=...,
        ...
    )

Sorry, I was in a meeting this morning, hence the late reply. I can run it successfully now, thanks a lot for your patience.

quancs commented 1 year ago

Great. It’s my pleasure.