ylacombe / musicgen-dreamboothing

Fine-tune your own MusicGen with LoRA
Apache License 2.0

DataCollatorMusicGenWithPadding might have a bug #4

Open LiuZH-19 opened 6 months ago

LiuZH-19 commented 6 months ago

Thank you very much for your amazing work! While using melody-conditioned generation, I encountered the following error in the DataCollatorMusicGenWithPadding class:

batch[self.feature_extractor_input_name : input_values]
TypeError: unhashable type: 'slice'

input_values here is actually a dictionary, so I resolved the error by replacing that line with batch.update(input_values). Could you kindly confirm whether this approach is correct?
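For reference, a minimal standalone sketch of what goes wrong and how the two obvious rewrites behave; the key and values below are placeholders, not the real processor outputs:

# The colon inside the brackets builds slice(key, value) and tries to *look up*
# that slice in the dict; slice objects are unhashable, hence the TypeError.
feature_extractor_input_name = "input_features"    # placeholder name
input_values = {"input_features": [[0.1, 0.2]]}    # placeholder padded dict
batch = {"labels": [], "input_ids": []}

try:
    batch[feature_extractor_input_name:input_values]
except TypeError as err:
    print(err)  # unhashable type: 'slice'

batch.update(input_values)                           # merges the padded keys into the batch
# batch[feature_extractor_input_name] = input_values # alternative: nests the whole dict under one key

My patched __call__ (with the original line left commented out) is below: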

def __call__(
    self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
) -> Dict[str, torch.Tensor]:
    # split inputs and labels since they have to be of different lengths and need
    # different padding methods
    labels = [
        torch.tensor(feature["labels"]).transpose(0, 1) for feature in features
    ]
    # (bsz, seq_len, num_codebooks)
    labels = torch.nn.utils.rnn.pad_sequence(
        labels, batch_first=True, padding_value=-100
    )

    input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
    input_ids = self.processor.tokenizer.pad(input_ids, return_tensors="pt")

    batch = {"labels": labels, **input_ids}

    if self.feature_extractor_input_name in features[0]:
        # pad the feature-extractor inputs (here, the melody conditioning) and
        # merge the resulting dict of padded tensors into the batch
        input_values = [
            {
                self.feature_extractor_input_name: feature[
                    self.feature_extractor_input_name
                ]
            }
            for feature in features
        ]
        input_values = self.processor.feature_extractor.pad(
            input_values, return_tensors="pt"
        )
        batch.update(input_values)
        # original line, which raised TypeError: unhashable type: 'slice':
        # batch[self.feature_extractor_input_name : input_values]

    return batch
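For reference, a hedged sketch of how the patched collator could be wired into a DataLoader to sanity-check the returned keys. The constructor arguments (processor, feature_extractor_input_name) mirror the attributes read inside __call__, and train_dataset stands in for a dataset that already yields dicts with labels, input_ids and the feature-extractor inputs:

from torch.utils.data import DataLoader
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("facebook/musicgen-melody")
collator = DataCollatorMusicGenWithPadding(
    processor=processor,
    feature_extractor_input_name=processor.feature_extractor.model_input_names[0],
)

loader = DataLoader(train_dataset, batch_size=2, collate_fn=collator)
batch = next(iter(loader))
print(batch.keys())  # expect labels, input_ids, attention_mask, plus the padded feature-extractor keys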
hieuhthh commented 4 months ago

I'm hitting the same issue, could you fix it? Also, how do I run inference with the conditioning audio and a text prompt together? Thank you!
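Not from this repo, but a hedged sketch of melody-plus-prompt inference using the upstream transformers MusicGen-Melody classes; the checkpoint name, the synthetic stand-in melody and the generation settings are assumptions, and a LoRA adapter trained with this repo would be loaded on top (e.g. with peft) before generating:

import numpy as np
import torch
from transformers import AutoProcessor, MusicgenMelodyForConditionalGeneration

processor = AutoProcessor.from_pretrained("facebook/musicgen-melody")
model = MusicgenMelodyForConditionalGeneration.from_pretrained("facebook/musicgen-melody")

# stand-in conditioning audio: 5 seconds of a 220 Hz tone at the feature
# extractor's sampling rate; replace with a real melody clip
sr = processor.feature_extractor.sampling_rate
t = np.arange(0, 5.0, 1.0 / sr)
melody = 0.5 * np.sin(2 * np.pi * 220.0 * t)

inputs = processor(
    audio=melody,
    sampling_rate=sr,
    text=["80s synth-pop with a driving bassline"],
    padding=True,
    return_tensors="pt",
)

with torch.no_grad():
    audio_values = model.generate(**inputs, do_sample=True, max_new_tokens=256)

# audio_values: waveform tensor of shape (batch, channels, samples), sampled at
# model.config.audio_encoder.sampling_rate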