facebookresearch / audiocraft

Audiocraft is a library for audio processing and generation with deep learning. It features the state-of-the-art EnCodec audio compressor / tokenizer, along with MusicGen, a simple and controllable music generation LM with textual and melodic conditioning.
MIT License

MusicGen: Missing 'rtf' assignment in prompted samples generation. #458

Open: ODD2 opened this issue 1 month ago


Hi, I'm currently experimenting with MusicGen and ran into an exception saying `rtf` is undefined when the solver is configured to generate only prompted samples, i.e. when the `generate` section in config/solver/musicgen/default.yaml is modified as follows:

```yaml
generate:
  every: 25
  num_workers: 5
  path: samples
  audio:
    format: wav
    strategy: loudness
    sample_rate: ${sample_rate}
    loudness_headroom_db: 14
  lm:
    prompted_samples: true
    unprompted_samples: false # <- this line is modified
    gen_gt_samples: false
    prompt_duration: null   # if not set, will use dataset.generate.segment_duration / 4
    gen_duration: null      # if not set, will use dataset.generate.segment_duration
    remove_prompts: false
    # generation params
    use_sampling: false
    temp: 1.0
    top_k: 0
    top_p: 0.0
```
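
Per the inline comments, the two `null` durations fall back to values derived from the dataset's generation segment length. A minimal sketch of that fallback, using a hypothetical helper `resolve_durations` (not part of audiocraft) that takes `dataset.generate.segment_duration` as a plain float:

```python
from typing import Optional, Tuple


def resolve_durations(segment_duration: float,
                      prompt_duration: Optional[float] = None,
                      gen_duration: Optional[float] = None) -> Tuple[float, float]:
    """Hypothetical helper mirroring the config comments:
    prompt_duration defaults to segment_duration / 4,
    gen_duration defaults to segment_duration."""
    if prompt_duration is None:
        prompt_duration = segment_duration / 4
    if gen_duration is None:
        gen_duration = segment_duration
    return prompt_duration, gen_duration


# Both values left as null with a 30 s segment:
print(resolve_durations(30.0))  # (7.5, 30.0)
```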

My guess is that the `rtf` assignment is missing from the prompted-sample generation branch:

```python
# line 577 in audiocraft/solvers/musicgen.py
if self.cfg.generate.lm.prompted_samples:
    gen_outputs = self.run_generate_step(
        batch, gen_duration=target_duration, prompt_duration=prompt_duration,
        **self.generation_params)
    gen_audio = gen_outputs['gen_audio'].cpu()
    prompt_audio = gen_outputs['prompt_audio'].cpu()
    sample_manager.add_samples(
        gen_audio, self.epoch, hydrated_conditions,
        prompt_wavs=prompt_audio, ground_truth_wavs=audio,
        generation_args=sample_generation_params)
    # rtf = gen_outputs['rtf']  <- missing?
```
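
The failure pattern can be reproduced without audiocraft. The sketch below is a simplified, hypothetical stand-in for the solver's generation step (`fake_generate` and `generate_metrics` are illustrative names, not audiocraft APIs): `rtf` is bound only inside the unprompted branch, so with `unprompted_samples` disabled the metric assignment reads an unbound local variable, which Python reports as an `UnboundLocalError`.

```python
def fake_generate() -> dict:
    # Hypothetical stand-in for run_generate_step: returns a dict with an 'rtf' entry.
    return {"rtf": 0.5}


def generate_metrics(unprompted_samples: bool, prompted_samples: bool) -> dict:
    metrics = {}
    if unprompted_samples:
        rtf = fake_generate()["rtf"]  # the only branch that binds rtf
    if prompted_samples:
        gen_outputs = fake_generate()  # rtf is never assigned in this branch
    metrics["rtf"] = rtf  # fails when the unprompted branch was skipped
    return metrics


print(generate_metrics(unprompted_samples=True, prompted_samples=True))  # {'rtf': 0.5}
try:
    generate_metrics(unprompted_samples=False, prompted_samples=True)
except UnboundLocalError as exc:
    print(type(exc).__name__)  # UnboundLocalError
```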

For now, I've modified the generation section so that the `rtf` metric is collected from whichever branches run:

```python
# line 560 in audiocraft/solvers/musicgen.py
rtf = []  # <- modified
if self.cfg.generate.lm.unprompted_samples:
    if self.cfg.generate.lm.gen_gt_samples:
        # get the ground truth instead of generation
        self.logger.warn(
            "Use ground truth instead of audio generation as generate.lm.gen_gt_samples=true")
        gen_unprompted_audio = audio
        rtf.append(1.)  # <- modified
    else:
        gen_unprompted_outputs = self.run_generate_step(
            batch, gen_duration=target_duration, prompt_duration=None,
            **self.generation_params)
        gen_unprompted_audio = gen_unprompted_outputs['gen_audio'].cpu()
        rtf.append(gen_unprompted_outputs['rtf'])  # <- modified
    sample_manager.add_samples(
        gen_unprompted_audio, self.epoch, hydrated_conditions,
        ground_truth_wavs=audio, generation_args=sample_generation_params)

if self.cfg.generate.lm.prompted_samples:
    gen_outputs = self.run_generate_step(
        batch, gen_duration=target_duration, prompt_duration=prompt_duration,
        **self.generation_params)
    gen_audio = gen_outputs['gen_audio'].cpu()
    prompt_audio = gen_outputs['prompt_audio'].cpu()
    sample_manager.add_samples(
        gen_audio, self.epoch, hydrated_conditions,
        prompt_wavs=prompt_audio, ground_truth_wavs=audio,
        generation_args=sample_generation_params)
    rtf.append(gen_outputs['rtf'])  # <- modified

# <- modified: mean RTF over the branches that ran; max() avoids dividing by zero
metrics['rtf'] = sum(rtf) / max(len(rtf), 1)
```
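
Averaging the collected values needs a denominator of `max(len(rtf), 1)`: with `min(len(rtf), 1)` the denominator is 0 for an empty list (a `ZeroDivisionError` when both sample types are disabled) and 1 otherwise, which yields a sum rather than a mean. A small standalone sketch of the aggregation follows. `mean_rtf` and `timed_rtf` are hypothetical helpers, and RTF is assumed here to mean wall-clock generation time divided by the duration of audio generated, the usual real-time-factor definition.

```python
import time
from typing import Callable, List


def timed_rtf(generate: Callable[[], None], gen_duration: float) -> float:
    """Assumed RTF definition: wall-clock generation time / seconds of audio produced."""
    start = time.perf_counter()
    generate()
    return (time.perf_counter() - start) / gen_duration


def mean_rtf(rtf: List[float]) -> float:
    """Hypothetical helper: average RTF over the branches that ran.

    max(len(rtf), 1) divides by the true count when values exist and by 1
    (giving 0.0) when the list is empty, instead of raising ZeroDivisionError.
    """
    return sum(rtf) / max(len(rtf), 1)


print(mean_rtf([0.25, 0.75]))  # 0.5
print(mean_rtf([]))            # 0.0
values = [0.25, 0.75]
print(sum(values) / min(len(values), 1))  # 1.0 (a sum, not a mean)
```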

Please let me know if the modification is correct. Thanks for the great work.