Audio-WestlakeU / NBSS

The official repo of NBC & SpatialNet for multichannel speech separation, denoising, and dereverberation
MIT License

Using one real recording for inference #15

Closed wendongj closed 11 months ago

wendongj commented 11 months ago

Hi, dear author. It's fortunate to see the state-of-the-art results on the multi-channel speech separation task, and thanks also for open-sourcing the code. Thanks to the detailed introduction in the README, I can now train the model with a two-mic config. I just want to ask a stupid question: if I want to run inference on a real recorded 2-mic signal whose length is not 4 seconds, what should I do? Sorry to disturb you.

quancs commented 11 months ago

Thank you for your attention to our work. To run inference on a real-recorded audio, you can

wendongj commented 11 months ago

Thanks for your reply. When I directly use "the corresponding command", I get an error:

    NBSSCLI.py: error: Unrecognized arguments: --ckpt_path=logs/NBSS/version_8/checkpoints/epoch31_neg_si_sdr-15.4105.ckpt --trainer.devices=0,

Do I need to do the "give run=False to NBSSCLI" step before running the command? I did not find how to give run=False.

quancs commented 11 months ago

Do I need to do the "give run=False to NBSSCLI" step before running the command? I did not find how to give run=False.

Yes, you should change the code a little bit first. In NBSS/NBSSCLI.py:

    cli = NBSSCLI(
        NBSS,
        pl.LightningDataModule,
        save_config_callback=SaveConfigCallback,
        save_config_kwargs={'overwrite': True},
        subclass_mode_data=True,
        run=False, # add this line
    )
    # add following lines
    model = cli.model
    batch_size, num_chn, length = 1, 2, 64000
    audio = torch.rand((batch_size, num_chn, length))
    preds = model.predict_step(audio)
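(For context: with run=False, LightningCLI only parses the config and instantiates the classes instead of running a trainer subcommand, which is why cli.model becomes available for manual use.)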
quancs commented 11 months ago

You can also 1) create an instance of the NBSS module using the arguments in config.yaml, 2) manually load the weights like a normal torch.nn.Module, 3) call the predict_step of the NBSS module.

wendongj commented 11 months ago

Do I need to do the "give run=False to NBSSCLI" step before running the command? I did not find how to give run=False.

Yes, you should change the code a little bit first. In NBSS/NBSSCLI.py:

    cli = NBSSCLI(
        NBSS,
        pl.LightningDataModule,
        save_config_callback=SaveConfigCallback,
        save_config_kwargs={'overwrite': True},
        subclass_mode_data=True,
        run=False, # add this line
    )
    # add following lines
    model = cli.model
    batch_size, num_chn, length = 1, 2, 64000
    audio = torch.rand((batch_size, num_chn, length))
    preds = model.predict_step(audio)

Thanks, I will try. And btw, your profile photo is handsome.

quancs commented 11 months ago

your profile photo is handsome.

Thank you ^_^

wendongj commented 11 months ago

your profile photo is handsome.

Thank you ^_^

Really sorry to disturb you again at night. Here is what I did. First, I changed the NBSSCLI python file as follows:

    cli = NBSSCLI(
        NBSS,
        pl.LightningDataModule,
        save_config_callback=SaveConfigCallback,
        save_config_kwargs={'overwrite': True},
        subclass_mode_data=True,
        run=False, # add this line
    )
    # add following lines
    model = cli.model
    batch_size, num_chn, length = 1, 2, 64000
    audio = torch.rand((batch_size, num_chn, length))
    preds = model.predict_step(audio)

Then I ran the command:

    python NBSSCLI.py --config=logs/NBSS/version_8/config.yaml --ckpt_path=logs/NBSS/version_8/checkpoints/epoch32_neg_si_sdr-15.3829.ckpt --trainer.devices=0,

The error is:

    NBSSCLI.py: error: Unrecognized arguments: --ckpt_path=logs/NBSS/version_8/checkpoints/epoch32_neg_si_sdr-15.3829.ckpt

I tried many times and made sure the ckpt file is right. When I remove ckpt_path from the command, I get:

    NBSSCLI.py: error: 'Configuration check failed :: No action for destination key "ckpt_path" to check its value

It seems ckpt_path is really needed, but when I add it to the command, the error above comes back.

quancs commented 11 months ago

Sorry, it's my mistake; the given command doesn't seem to work in the new version. But you can still run inference on the audio by following:

You can also 1) create an instance of the NBSS module using the arguments in config.yaml, 2) manually load the weights like a normal torch.nn.Module, 3) call the predict_step of the NBSS module.

wendongj commented 11 months ago

Many thanks, I will try that ^_^

quancs commented 11 months ago

Give me a few minutes. I will write a short example for this.

wendongj commented 11 months ago

Oh, I will wait at my seat until tomorrow. It's really my honor to see your examples, and thanks a lot.

quancs commented 11 months ago

I verified that the following code works. Remove any arguments in config.yaml if they cause errors.

# CMD:  python NBSSCLI.py --config logs/NBSS/version_x/config.yaml
if __name__ == '__main__':
    cli = NBSSCLI(
        NBSS,
        pl.LightningDataModule,
        save_config_kwargs={'overwrite': True},
        save_config_callback=SaveConfigCallback,
        subclass_mode_data=True,
        run=False,  # add this line
    )
    model = cli.model

    # load weights from checkpoint
    ckpt = 'logs/NBSS/version_293/checkpoints/last.ckpt'
    data = torch.load(ckpt, map_location='cpu')
    model.on_load_checkpoint(data)
    model.load_state_dict(data['state_dict'])
    model.eval()

    # inference an audio
    batch_size, num_chn, length = 1, 2, 64000
    audio = torch.rand((batch_size, num_chn, length))
    preds = model.predict_step(audio)
    print(preds.shape)
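If you want to feed a real recording instead of random data, a minimal sketch (assuming a 2-channel wav file whose sample rate matches the one used in training; 'mix.wav' is just a placeholder path) could replace the last block with:

import torch
import torchaudio

# 'mix.wav' is a placeholder; use your own 2-channel recording
# (resample it first if its sample rate differs from training)
audio, sr = torchaudio.load('mix.wav')  # shape: (num_chn, length)
audio = audio.unsqueeze(0)              # add batch dim -> (1, num_chn, length)
with torch.no_grad():
    preds = model.predict_step(audio)   # model comes from the code above
print(preds.shape)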
wendongj commented 11 months ago

I verified that the following code works.

# CMD:  python NBSSCLI.py --config logs/NBSS/version_x/config.yaml
if __name__ == '__main__':
    cli = NBSSCLI(
        NBSS,
        pl.LightningDataModule,
        save_config_kwargs={'overwrite': True},
        save_config_callback=SaveConfigCallback,
        subclass_mode_data=True,
        run=False,  # add this line
    )
    model = cli.model

    # load weights from checkpoint
    ckpt = 'logs/NBSS/version_293/checkpoints/last.ckpt'
    data = torch.load(ckpt, map_location='cpu')
    model.on_load_checkpoint(data)
    model.load_state_dict(data['state_dict'])
    model.eval()

    # inference an audio
    batch_size, num_chn, length = 1, 2, 64000
    audio = torch.rand((batch_size, num_chn, length))
    preds = model.predict_step(audio)
    print(preds.shape)

Thanks a lot. It's raining outside; don't go home too late.

quancs commented 11 months ago

Another way:

from models.NBSS import NBSS
import torch

model=NBSS( # find the arguments from config.yaml.
    arch=..., # the input and output size for arch: input_size=2*num_channels; output_size=2*num_spk
    io=...,
    ...
)

# load weights from checkpoint
ckpt = 'logs/NBSS/version_293/checkpoints/last.ckpt'
data = torch.load(ckpt, map_location='cpu')
model.on_load_checkpoint(data)
model.load_state_dict(data['state_dict'])
model.eval()

# inference an audio
batch_size, num_chn, length = 1, 2, 64000
audio = torch.rand((batch_size, num_chn, length))
preds = model.predict_step(audio)
print(preds.shape)
wendongj commented 11 months ago

Another way:

from models.NBSS import NBSS
import torch

model=NBSS( # find the arguments from config.yaml.
    arch=..., # the input and output size for arch: input_size=2*num_channels; output_size=2*num_spk
    io=...,
    ...
)

# load weights from checkpoint
ckpt = 'logs/NBSS/version_293/checkpoints/last.ckpt'
data = torch.load(ckpt, map_location='cpu')
model.on_load_checkpoint(data)
model.load_state_dict(data['state_dict'])
model.eval()

# inference an audio
batch_size, num_chn, length = 1, 2, 64000
audio = torch.rand((batch_size, num_chn, length))
preds = model.predict_step(audio)
print(preds.shape)

Oh, I will try both ways.

wendongj commented 11 months ago

Thanks a lot. It's raining outside; don't go home too late.

BRO, WHO ARE YOU?

I am an engineer, mainly working on speech enhancement and echo cancellation, and recently preparing to do multi-channel speech separation. I found your team's paper through the FT-JNF paper.

quancs commented 11 months ago

Sorry, I thought you were one of my friends playing a joke on me...

wendongj commented 11 months ago

I verified that the following code works. Remove any arguments in config.yaml if they cause errors.

# CMD:  python NBSSCLI.py --config logs/NBSS/version_x/config.yaml
if __name__ == '__main__':
    cli = NBSSCLI(
        NBSS,
        pl.LightningDataModule,
        save_config_kwargs={'overwrite': True},
        save_config_callback=SaveConfigCallback,
        subclass_mode_data=True,
        run=False,  # add this line
    )
    model = cli.model

    # load weights from checkpoint
    ckpt = 'logs/NBSS/version_293/checkpoints/last.ckpt'
    data = torch.load(ckpt, map_location='cpu')
    model.on_load_checkpoint(data)
    model.load_state_dict(data['state_dict'])
    model.eval()

    # inference an audio
    batch_size, num_chn, length = 1, 2, 64000
    audio = torch.rand((batch_size, num_chn, length))
    preds = model.predict_step(audio)
    print(preds.shape)

When I use the changed code in NBSSCLI.py, I get the following error, but on your side there were no errors, which is strange. It prints the full usage of NBSSCLI.py (all of the --trainer.*, --model.*, --data, --progress_bar.*, --early_stopping.*, --model_checkpoint.*, --optimizer, and --lr_scheduler options) and then:

    NBSSCLI.py: error: Parser key "model.io": Type <class 'models.io.narrow_band.nbio.NBIO'> expects: a class path (str); or a dict with a class_path entry; or a dict with init_args (if class path given previously). Got "None".

quancs commented 11 months ago

Another way:

from models.NBSS import NBSS
import torch

model=NBSS( # find the arguments from config.yaml.
    arch=..., # the input and output size for arch: input_size=2*num_channels; output_size=2*num_spk
    io=...,
    ...
)

# load weights from checkpoint
ckpt = 'logs/NBSS/version_293/checkpoints/last.ckpt'
data = torch.load(ckpt, map_location='cpu')
model.on_load_checkpoint(data)
model.load_state_dict(data['state_dict'])
model.eval()

# inference an audio
batch_size, num_chn, length = 1, 2, 64000
audio = torch.rand((batch_size, num_chn, length))
preds = model.predict_step(audio)
print(preds.shape)

@wendongj Please try this. I think it is the simplest way.

wendongj commented 11 months ago

Should the NBSS model be created like this?

import yaml
from models.NBSS import NBSS

def read_para(file):
    f = open(file, 'r', encoding='utf-8')
    data = f.read()
    return yaml.load(data, Loader=yaml.FullLoader)

config_data = read_para('logs/NBSS/version_8/config.yaml')

model = NBSS(  # find the arguments from config.yaml
    arch=config_data['model']['arch'],  # the input and output size for arch: input_size=2*num_channels; output_size=2*num_spk
    io=config_data['model']['io'],
)

quancs commented 11 months ago

You should instantiate io and arch with their classes and arguments, e.g. arch=models.arch.NBCv2.NBCv2(...). Their class paths and arguments can be found in config.yaml.
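For example, a minimal sketch (assuming config.yaml uses the LightningCLI class_path/init_args layout; instantiate is a hypothetical helper, not part of this repo):

import importlib
import yaml
from models.NBSS import NBSS

def instantiate(node):
    # build an object from a {class_path, init_args} config node
    # (a real config may nest further class_path entries, which would need recursion)
    module_path, cls_name = node['class_path'].rsplit('.', 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**node.get('init_args', {}))

with open('logs/NBSS/version_8/config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

model = NBSS(
    arch=instantiate(config['model']['arch']),
    io=instantiate(config['model']['io']),
)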

wendongj commented 11 months ago

model=NBSS( # find the arguments from config.yaml.
    arch=..., # the input and output size for arch: input_size=2*num_channels; output_size=2*num_spk
    io=...,
    ...
)

Sorry, I was in a meeting this morning, hence the late reply. I can now run it successfully. Thanks a lot for your patience.

quancs commented 11 months ago

Great. It’s my pleasure.