ylacombe / musicgen-dreamboothing

Fine-tune your own MusicGen with LoRA
Apache License 2.0
109 stars, 12 forks

Incorporating input_values for Audio-Prompted Audio Generation in Musicgen Model Training #5

Open LiuZH-19 opened 6 months ago

LiuZH-19 commented 6 months ago

I want to train the MusicGen model (rather than the MusicGen Melody model) for audio-prompted continuation/generation tasks. As far as I can tell from the code below, input_values (i.e., the audio prompt) is not used during training: https://github.com/huggingface/transformers/blob/main/src/transformers/models/musicgen/modeling_musicgen.py#L2254

if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):  # training
    decoder_input_ids = shift_tokens_right(
        labels, self.config.decoder.pad_token_id, self.config.decoder.decoder_start_token_id
    )

elif decoder_input_ids is None and decoder_inputs_embeds is None:  # When to use?
    audio_encoder_outputs = self.audio_encoder(
        input_values=input_values,
        padding_mask=padding_mask,
        **kwargs_audio_encoder,
    )
    audio_codes = audio_encoder_outputs.audio_codes
    frames, bsz, codebooks, seq_len = audio_codes.shape
    if frames != 1:
        raise ValueError(
            f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is "
            "disabled by setting `chunk_length=None` in the audio encoder."
        )

    if self.config.decoder.audio_channels == 2 and audio_codes.shape[2] == self.decoder.num_codebooks // 2:
        # mono input through encodec that we convert to stereo
        audio_codes = audio_codes.repeat_interleave(2, dim=2)

    decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len)

If I want to incorporate input_values into training, what modifications should I make? Any guidance would be greatly appreciated. Thank you for your assistance!

hieuhthh commented 5 months ago

Do you have any solution for this?
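
One possible direction, offered as a sketch rather than a confirmed fix: in the snippet quoted above, neither branch runs when decoder_input_ids is passed explicitly, so the audio prompt could be encoded to codes ahead of time, prepended to the target codes, and supplied as decoder_input_ids, with the prompt positions in labels set to -100 so the loss only covers the continuation. Passing decoder_input_ids explicitly seems to matter here, because if they are derived from labels by shift_tokens_right, masked label positions would no longer carry the prompt tokens. The helper below is hypothetical (build_continuation_example is not part of transformers or this repo), works on plain per-codebook lists of token ids, and ignores the codebook delay pattern and batching, which a real training script would still have to handle. The start token value in the usage note is a placeholder; read it from config.decoder.decoder_start_token_id.

```python
IGNORE_INDEX = -100  # label value that PyTorch cross-entropy skips by default


def build_continuation_example(prompt_codes, target_codes, start_token_id):
    """Hypothetical helper: build teacher-forcing inputs and masked labels.

    prompt_codes / target_codes: one list of token ids per codebook.
    Assumption: the delay pattern is applied elsewhere (not shown here).
    """
    if len(prompt_codes) != len(target_codes):
        raise ValueError("prompt and target must have the same number of codebooks")

    decoder_input_ids, labels = [], []
    for prompt, target in zip(prompt_codes, target_codes):
        full = list(prompt) + list(target)
        # Teacher forcing: the input at step t is the token from step t - 1,
        # with the start token in the first position.
        decoder_input_ids.append([start_token_id] + full[:-1])
        # Only the continuation contributes to the loss; the prompt is masked.
        labels.append([IGNORE_INDEX] * len(prompt) + list(target))
    return decoder_input_ids, labels


if __name__ == "__main__":
    ids, lab = build_continuation_example(
        prompt_codes=[[1, 2], [3, 4]],
        target_codes=[[5, 6, 7], [8, 9, 10]],
        start_token_id=2048,  # placeholder value, check your model config
    )
    print(ids)
    print(lab)
```

The resulting lists would then need to be converted to tensors in the shapes the model expects before being passed as decoder_input_ids and labels; whether the loss computation accepts this combination unchanged is something to verify against the installed transformers version.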