lucidrains / audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch

have trouble to generate semantic tokens using the demo code #235

Closed · dwangF0 closed this 12 months ago

dwangF0 commented 12 months ago

I am wondering if anyone has been able to successfully run the following code from the demo without making any changes. I mean the demo for generating semantic tokens given text information; here is the link.

run sample = trainer.generate(text = ['sound of rain drops on the rooftops'], batch_size = 1, max_length = 2)

I can run the training part before the above line successfully, but inference at the above line fails.

I would like to verify several parts:

1. line 1423 of ./audiolm_pytorch/audiolm_pytorch.py

for ind in tqdm(range(start_length, max_length), desc = 'generating semantic'):

Does "ind" in the above line indicate each new estimated semantic token?

2. line 456 - line 463 of ./audiolm_pytorch/audiolm_pytorch.py

    if exists(kv_cache):
        cache_len = kv_cache.shape[-2]
        kv_cache = iter(kv_cache)
    else:
        cache_len = 0
        kv_cache = iter([])

    x = x[:, cache_len:]

When running the code, I printed the shapes of kv_cache and x, as below:

kv_cache shape when it exists: torch.Size([6, 2, 1, 18, 64])
x shape after grad_shrink: torch.Size([1, 2, 1024])

You can see that kv_cache.shape[-2] = 18 is much larger than x.shape[1] = 2. So I am wondering how kv_cache.shape[-2] can be used as a slicing offset for x along its second dimension.
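For context, here is a toy illustration (hypothetical shapes and names, not the library's code) of the kv-cache contract I would expect this slice to assume: x is the full sequence so far, so cache_len should never exceed x.shape[1].

    # toy illustration of the usual kv-cache contract (hypothetical shapes)
    import torch

    batch, dim = 1, 1024

    cached_k = torch.randn(batch, 17, 64)   # keys for 17 already-processed positions
    cache_len = cached_k.shape[-2]          # 17

    x_full = torch.randn(batch, 18, dim)    # full sequence so far: 17 old positions + 1 new token
    x_new = x_full[:, cache_len:]           # (1, 1, dim) -> only the new token is processed

    assert x_new.shape[1] == x_full.shape[1] - cache_len

In my run, cache_len (18) is larger than x.shape[1] (2), so this slice leaves nothing.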

3. line 122 of ./audiolm_pytorch/attend.py

sim = sim + attn_bias

With the original code, the shapes of sim and attn_bias always mismatch once ind > 0. Therefore, I had to add a step to pad attn_bias and the mask on each iteration of ind (from start_length to max_length). I am wondering whether my logic is correct.
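For reference, here is a toy sketch (hypothetical shapes, not attend.py itself) of the cropping I would have expected instead of padding: with a kv cache, the scores have q_len < k_len, so a full-size position bias is usually sliced down to the rows of the new queries before being added.

    # toy sketch of aligning an attention bias with kv-cached scores (hypothetical)
    import torch

    batch, heads, q_len, k_len, dim_head = 1, 8, 1, 18, 64

    q = torch.randn(batch, heads, q_len, dim_head)   # only the new query position(s)
    k = torch.randn(batch, heads, k_len, dim_head)   # cached + new keys

    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)  # (1, 8, 1, 18)

    attn_bias = torch.randn(heads, k_len, k_len)     # bias built for the full sequence
    attn_bias = attn_bias[..., -q_len:, :]           # keep only the rows for the new queries

    sim = sim + attn_bias                            # broadcasts to (1, 8, 1, 18)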

Looking forward to any insights regarding these problems.

Thanks!

lucidrains commented 12 months ago

@dwangF0 ah yea, the reason is that the new kv cache for faster inference is not compatible with the vall-e way of conditioning (with the condition as a prefix)

i've turned it off for now
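A toy calculation with assumed numbers (the prefix length of 16 below is made up) illustrates the incompatibility: with the condition prepended, the cache length counts the prefix tokens, while x only contains the semantic tokens the caller passed in.

    # hypothetical numbers, just to illustrate the mismatch reported above
    text_prefix_len = 16      # conditioning tokens prepended inside the model (assumed)
    semantic_len = 2          # tokens the sampling loop is aware of

    cache_len = text_prefix_len + semantic_len   # 18, matching kv_cache.shape[-2] above
    assert cache_len > semantic_len              # so x[:, cache_len:] slices past the end of x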

dwangF0 commented 12 months ago

Thank you so much! It works now.