facebookresearch / audiocraft

Audiocraft is a library for audio processing and generation with deep learning. It features the state-of-the-art EnCodec audio compressor / tokenizer, along with MusicGen, a simple and controllable music generation LM with textual and melodic conditioning.
MIT License

Incorrect tensor dimension in method `LMModel.forward()` docstring #453

Open Daniel-Chin opened 2 months ago

Daniel-Chin commented 2 months ago

There seems to be a mistake in the docstring of `audiocraft.models.lm.LMModel.forward()`.

At https://github.com/facebookresearch/audiocraft/blob/87af0bfddd489c5f22d2ea6743fb4afbe092539e/audiocraft/models/lm.py#L226:

    def forward(self, sequence: torch.Tensor,
                conditions: tp.List[ConditioningAttributes],
                condition_tensors: tp.Optional[ConditionTensors] = None,
                stage: int = -1) -> torch.Tensor:
        """Apply language model on sequence and conditions.
        Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
        S the sequence steps, return the logits with shape [B, card, K, S]. ..."""

it claims the output logits have shape `[B, card, K, S]`, whereas the method actually returns shape `[B, K, S, card]`.
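A one-line fix would be to swap the dimensions in the docstring, leaving the rest of it unchanged:

        Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
        S the sequence steps, return the logits with shape [B, K, S, card]. ...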

To reproduce `lm.forward` returning shape [B, K, S, card]:

import math

import torch

from audiocraft.models.musicgen import MusicGen
from audiocraft.modules.conditioners import ClassifierFreeGuidanceDropout

musicGen = MusicGen.get_pretrained('facebook/musicgen-small', device='cuda')
musicGen.set_generation_params(
    use_sampling=True,
    top_k=250,
    duration=30
)

TEMP_SR = 32000
def get_bip_bip(bip_duration=0.125, frequency=440,
                duration=0.5, sample_rate=TEMP_SR, device="cuda"):
    """Generates a series of bip bip at the given frequency."""
    t = torch.arange(
        int(duration * sample_rate), device="cuda", dtype=torch.float) / sample_rate
    wav = torch.cos(2 * math.pi * 440 * t)[None]
    tp = (t % (2 * bip_duration)) / (2 * bip_duration)
    envelope = (tp >= 0.5).float()
    return wav * envelope
bipbip = get_bip_bip()
# [1, T] -> [1, 1, T]: batch of one mono waveform
prompt = bipbip.expand(1, -1, -1).cuda()

attributes, prompt_tokens = musicGen._prepare_tokens_and_attributes([None], prompt)

with musicGen.autocast:
    lm = musicGen.lm
    # build conditional + null conditions for classifier-free guidance
    null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(attributes)
    conditions = attributes + null_conditions
    tokenized = lm.condition_provider.tokenize(conditions)
    cfg_conditions = lm.condition_provider(tokenized)

    # the prompt tokens are the known prefix of the generation
    codes = prompt_tokens
    B, K, T = codes.shape
    start_offset = T
    unknown_token = -1
    max_gen_len = 1500
    gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device='cuda')
    gen_codes[..., :start_offset] = codes
    pattern = lm.pattern_provider.get_pattern(max_gen_len)
    gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, lm.special_token_id)

    start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
    with lm.streaming():
        curr_sequence = gen_sequence[..., :start_offset_sequence]
        curr_mask = mask[None, ..., :start_offset_sequence].expand(B, -1, -1)

        # check coherence between mask and sequence
        assert (curr_sequence == torch.where(curr_mask, curr_sequence, lm.special_token_id)).all()
        # should never happen as gen_sequence is filled progressively
        assert not (curr_sequence == unknown_token).any()

        # duplicate the batch: one half conditioned, one half null (CFG)
        db_sequence = torch.cat([curr_sequence, curr_sequence], dim=0)
        print(db_sequence.shape, db_sequence.dtype)
        out = lm.forward(db_sequence, [], condition_tensors=cfg_conditions)
        print(out.shape)  # [B, K, S, card], not [B, card, K, S] as documented
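For `facebook/musicgen-small` (K = 4 codebooks, card = 2048), the last print should show `torch.Size([2, 4, S, 2048])` with `S = start_offset_sequence`, i.e. the card dimension comes last, not second.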
Daniel-Chin commented 2 months ago

Later in the same function we see the correct shape:

https://github.com/facebookresearch/audiocraft/blob/87af0bfddd489c5f22d2ea6743fb4afbe092539e/audiocraft/models/lm.py#L267
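If I read the code correctly, that is where the per-codebook linear heads are stacked along dim=1, and the inline comment there already documents the actual layout:

    logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1)  # [B, K, S, card]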