facebookresearch / audiocraft

Audiocraft is a library for audio processing and generation with deep learning. It features the state-of-the-art EnCodec audio compressor / tokenizer, along with MusicGen, a simple and controllable music generation LM with textual and melodic conditioning.
MIT License

Example use case of `compute_predictions`? #81

Open chavinlo opened 1 year ago

chavinlo commented 1 year ago

I am trying it in the following way:

import torch
import torchaudio
from audiocraft.models import MusicGen
from audiocraft.modules.conditioners import ClassifierFreeGuidanceDropout

model = MusicGen.get_pretrained('small')

# Build conditional + null-condition attributes for classifier-free guidance
attributes, _ = model._prepare_tokens_and_attributes(["sample text"], None)

conditions = attributes
null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
conditions = conditions + null_conditions
tokenized = model.lm.condition_provider.tokenize(conditions)
cfg_conditions = model.lm.condition_provider(tokenized)

condition_tensors = cfg_conditions

wav, sr = torchaudio.load("music.wav")
wav = torchaudio.functional.resample(wav, sr, 32000)  # 32 kHz is the model's sample rate
wav = wav.mean(dim=0, keepdim=True)  # downmix to mono
wav = wav.cuda()

with torch.no_grad():
    gen_audio = model.compression_model.encode(wav)

codes, scale = gen_audio

# codes: torch.Size([1, 4, 1500])

codes = torch.cat([codes, codes], dim=0)  # duplicate batch to match cond + null conditions
encoded_audio = codes

with model.autocast:
    lm_output = model.lm.compute_predictions(
        codes=encoded_audio,
        conditions=[],
        condition_tensors=condition_tensors
    )

logits, logits_mask = lm_output.logits[0:1], lm_output.mask[0:1] 

This is a minified version of my code, but it should reproduce the exact problem.
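For context on the batch doubling above: the codes are duplicated so that one forward pass produces both conditional and null-condition logits, which classifier-free guidance mixes at generation time; the `[0:1]` slice later keeps only the conditional half. A minimal sketch of that mixing with small random tensors standing in for the model output (the real shape here would be `[2, 4, 1500, 2048]`; `cfg_coef=3.0` is MusicGen's `generate()` default):

```python
import torch

# Stand-in for lm_output.logits from the doubled batch:
# [cond + uncond, codebooks, steps, vocab], shrunk for illustration
logits = torch.randn(2, 4, 8, 16)

cond_logits, uncond_logits = logits[0:1], logits[1:2]

# Classifier-free guidance mix of the two halves
cfg_coef = 3.0
guided = uncond_logits + (cond_logits - uncond_logits) * cfg_coef

assert guided.shape == (1, 4, 8, 16)
```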

logits come out with size: torch.Size([1, 4, 1500, 2048])

At first, most of the values are "normal":

tensor[1, 4, 1500, 2048] f16 n=12288000 (23Mb) x∈[-24.719, 26.031] μ=-2.340 σ=inf NaN! grad SliceBackward0 cuda:0
tensor([[[[-5.2832e-01, -2.6445e+00, -4.7930e+00,  ...,  3.3740e-01,
           -3.5664e+00, -4.5703e+00],
          [-9.8877e-01, -1.4834e+00, -5.2637e-01,  ...,  2.5039e+00,
           -7.3340e-01, -1.5537e+00],
          [-2.2129e+00, -3.0293e+00, -1.3350e+00,  ...,  2.2891e+00,
           -1.7686e+00, -4.6406e+00],
          ...,
          [-5.8438e+00, -5.1992e+00,  1.4587e-01,  ...,  6.2812e+00,
           -1.1548e-01, -6.1797e+00],
          [-4.3945e+00, -1.3457e+00,  1.0908e+00,  ...,  9.5859e+00,
            1.4971e+00, -4.2266e+00],
          [-3.4688e+00, -9.6797e+00, -2.6582e+00,  ...,  5.3438e+00,
           -1.1133e+00, -7.0117e+00]],

         [[-3.9697e-01, -1.0242e+01,  3.6401e-01,  ..., -6.4746e-01,
            3.4180e-01,  4.5166e-01],
          [-3.4863e-01, -1.0094e+01, -1.3008e+00,  ..., -2.4451e-01,
           -7.7734e-01,  6.3135e-01],
          [ 1.1505e-02, -9.4219e+00, -1.5020e+00,  ..., -3.9307e-01,
           -2.7051e-01,  6.4697e-01],
          ...,
          [ 8.2520e-01, -1.1211e+01, -6.3164e+00,  ..., -9.2850e-03,
            1.0059e+00,  2.7148e+00],
          [-6.7773e-01, -1.1391e+01, -2.4316e+00,  ..., -1.1924e+00,
            1.2744e+00,  6.1328e-01],
          [        nan,         nan,         nan,  ...,         nan,
                   nan,         nan]],

         [[ 7.5635e-01,  8.7036e-02, -1.6980e-01,  ..., -1.8740e+00,
           -2.1855e+00, -9.3203e+00],
          [ 4.4824e-01, -1.3262e+00,  9.6252e-02,  ..., -1.6074e+00,
           -1.5811e+00, -1.1305e+01],
          [ 4.8950e-01, -1.3320e+00, -1.7646e+00,  ..., -3.1270e+00,
           -5.1562e-01, -1.0336e+01],
          ...,
          [ 2.0859e+00, -1.9297e+00, -1.4326e+00,  ..., -1.8323e-01,
           -1.8799e-01, -9.5781e+00],
          [        nan,         nan,         nan,  ...,         nan,
                   nan,         nan],
          [        nan,         nan,         nan,  ...,         nan,
                   nan,         nan]],

         [[ 1.5186e+00, -8.0762e-01,  9.5673e-03,  ..., -1.3721e-01,
            1.5283e+00, -4.4214e-01],
          [-3.2266e+00, -2.3523e-01, -1.0516e-01,  ..., -3.2275e-01,
           -2.3022e-01,  7.0312e-01],
          [-4.5898e+00, -2.2180e-01, -1.9617e-01,  ..., -1.7688e-01,
            6.8701e-01, -2.5254e+00],
          ...,
          [        nan,         nan,         nan,  ...,         nan,
                   nan,         nan],
          [        nan,         nan,         nan,  ...,         nan,
                   nan,         nan],
          [        nan,         nan,         nan,  ...,         nan,
                   nan,         nan]]]], device='cuda:0', dtype=torch.float16,
       grad_fn=<SliceBackward0>)

In this tensor, for example, there are 12,288 NaNs.

I am trying to train, so I replace those NaNs with 0s. But after a backward pass, everything starts coming back as NaN.

Before this, I also tried another approach: passing the ConditioningAttributes via the `conditions` argument of compute_predictions rather than the pre-processed tensors via `condition_tensors`. However, that always returns NaNs as well.
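One way to keep the NaNs out of the loss entirely (instead of zeroing them) is to use the mask that compute_predictions returns alongside the logits and compute cross-entropy only over valid positions. A minimal sketch with toy tensors standing in for the model output (shapes shrunk from the real `[1, 4, 1500, 2048]`):

```python
import torch
import torch.nn.functional as F

B, K, T, card = 1, 4, 10, 16  # toy sizes
logits = torch.randn(B, K, T, card)
targets = torch.randint(0, card, (B, K, T))

mask = torch.ones(B, K, T, dtype=torch.bool)
mask[:, 1:, -3:] = False        # pretend the delayed tail steps are invalid
logits[~mask] = float('nan')    # NaN fill, as in compute_predictions' output

# Select valid positions before the loss, so NaNs never enter the graph
valid_logits = logits[mask]     # [n_valid, card]
valid_targets = targets[mask]   # [n_valid]
loss = F.cross_entropy(valid_logits, valid_targets)

assert torch.isfinite(loss)
```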

chavinlo commented 1 year ago

*I don't know much about lm training, so excuse me if I am missing something obvious

sakemin commented 1 year ago

I think it is because of the delay pattern of the EnCodec codebook structure. In compute_predictions there is a pattern.revert_pattern_logits call:

# note: we use nans as special token to make it obvious if we feed unexpected logits
logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
    logits, float('nan'), keep_only_valid_steps=True
)
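To see why those tails come out as NaN: with a delay interleaving, codebook k is shifted by k steps, so after reverting the pattern the last k steps of codebook k have no valid logits. A rough pure-Python illustration of the invalid positions (not audiocraft's actual pattern code):

```python
K, T = 4, 1500  # codebooks x timesteps, matching the example above

# With a per-codebook delay of k steps, reverting the pattern leaves the
# last k steps of codebook k without a valid prediction, hence NaN fill.
invalid = [(k, t) for k in range(K) for t in range(T) if t >= T - k]

# Codebook 0 has no invalid steps, codebook 1 has one, etc.:
# 0 + 1 + 2 + 3 = 6 invalid steps, times 2048 logits each = 12,288 NaNs,
# matching the count reported in the tensor printout above.
assert len(invalid) == 6
assert len(invalid) * 2048 == 12288
```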
Xiaohao-Liu commented 5 months ago

I think it is because you did not use automatic mixed precision for training, which is why it returns all NaNs after the first backward pass.
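If that is the cause (the fp16 weights overflowing without autocast during the backward pass), the usual remedy is PyTorch's automatic mixed precision with gradient scaling. A generic sketch of that training-step pattern with a toy model, not MusicGen's actual training loop (it falls back to a plain fp32 step on CPU):

```python
import torch

model = torch.nn.Linear(8, 4)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
use_cuda = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)  # no-op when disabled

x, target = torch.randn(2, 8), torch.randn(2, 4)
if use_cuda:
    model, x, target = model.cuda(), x.cuda(), target.cuda()

# Forward in mixed precision; backward on the scaled loss to avoid
# fp16 gradient underflow/overflow
with torch.autocast(device_type="cuda" if use_cuda else "cpu", enabled=use_cuda):
    loss = torch.nn.functional.mse_loss(model(x), target)

scaler.scale(loss).backward()
scaler.step(opt)     # skips the step if gradients contain inf/NaN
scaler.update()

assert all(torch.isfinite(p.grad).all() for p in model.parameters())
```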