lucidrains / audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch
MIT License

Audio generation failing at FineTransformer #199

Closed LWprogramming closed 1 year ago

LWprogramming commented 1 year ago

I tried training a model back when the repo was at commit 95e0669dde9c177b807fa6f0a52e4d2e685c47fd and got checkpoints successfully, but it crashed when I tried to test generation. The error was a hard-to-understand CUDA message:

generating fine:   0%|          | 0/512 [00:00<?, ?it/s]../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [480,0,0], thread: [32,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [480,0,0], thread: [33,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
... many repetitions ...
File "/fsx/itsleonwu/audiolm-pytorch-training/audiolm_pytorch/audiolm_pytorch.py", line 1617, in generate
    _, fine_logits = self.transformer.forward_with_cond_scale(
... more stuff ...

I suspect the problem is a bug in the coarse transformer's eos handling, because generation crashes right when the fine transformer is about to start. I printed the state and found that the coarse token ids contained a -1, which I think is the result of applying mask_out_after_eos_id. But the first -1 appeared at timestep 121, quantizer 2, i.e. index (0, 121, 2), which is not at the boundary between full quantizer steps; I'd expect the first -1 to appear somewhere like (batch_index, timestep, 0). That seems consistent with the CUDA failure: a -1 where all indices are expected to be small non-negative ints could easily cause an out-of-bounds memory access.
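For reference, here's a minimal standalone sketch (not code from this repo) of how a stray -1 token id can turn into that kind of device-side assert, assuming the id eventually goes through an embedding lookup:

```python
import torch
import torch.nn as nn

emb = nn.Embedding(1024, 512).cuda()             # valid token ids are 0..1023
token_ids = torch.randint(0, 1024, (1, 8)).cuda()
token_ids[0, -1] = -1                            # stray masked value, like the -1 I printed

out = emb(token_ids)                             # trips the srcIndex < srcSelectDimSize assert
torch.cuda.synchronize()                         # device asserts can surface late without a sync
```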

Going to use this issue to track updates and what I've tried. I'll be using the script at https://github.com/LWprogramming/audiolm-pytorch-training/blob/main/audiolm_pytorch_demo_laion.py (which I set up to eliminate non-determinism).

lucidrains commented 1 year ago

@LWprogramming oh thanks! i think this should be an easy fix

weird that @jinyuli didn't run into this error

do you have a small script that repros this error?

LWprogramming commented 1 year ago

Here is the script I'm running on a small dataset, but you can use an artificial one by uncommenting make_placeholder_dataset(), switching to dataset_folder = f"{prefix}/placeholder_dataset", and setting all the train steps etc. very low so you reach the error quickly.
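Roughly, the switch looks like this (paraphrased, not the exact lines from the script; the step-count variable names are just illustrative):

```python
# paraphrased from the repro script, not the exact lines
make_placeholder_dataset()                          # normally commented out; writes a tiny artificial dataset
dataset_folder = f"{prefix}/placeholder_dataset"    # point the trainers here instead of the real data

# then set every trainer's step counts very low so the run reaches generate() quickly
semantic_train_steps = coarse_train_steps = fine_train_steps = 10   # illustrative names and values
```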

And I confirmed that the eos is the problem because the assertion here triggered when my job ran last night

lucidrains commented 1 year ago

@LWprogramming yes, i do believe there was logic i misplaced into the FineTransformerWrapper, which AudioLM does not have access to

can you test out 1.1.6 and see if that works out?

LWprogramming commented 1 year ago

Just submitted the job to try it (still pending, so no results yet), but while we wait, to check my understanding: this change masks out anything that isn't an actual coarse index, so the transformer never learns anything that relies on those special tokens. But why does that prevent eos from appearing in the wrong spot (i.e. not aligned with the end of a quantizer step)? Or is the goal just to make that very low probability, since attention never sees eos during training?
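For what it's worth, this is roughly what I picture the training-time masking doing; just a sketch of my understanding, not the repo's actual code, and IGNORE_INDEX / masked_loss are made-up names:

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100                      # sentinel meaning "don't train on this position"

def masked_loss(logits, targets, special_token_ids):
    # logits: (batch, seq_len, num_tokens); targets: (batch, seq_len)
    targets = targets.clone()
    for tok in special_token_ids:        # eos, padding, anything that isn't a real coarse index
        targets[targets == tok] = IGNORE_INDEX
    return F.cross_entropy(
        logits.transpose(1, 2),          # cross_entropy expects (batch, num_classes, seq_len)
        targets,
        ignore_index=IGNORE_INDEX,       # excluded positions contribute nothing to the loss
    )
```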

lucidrains commented 1 year ago

i'll have to take a closer look later this weekend since i just got back from my trip and am still readjusting / running errands, but i don't believe the eos token is even kept https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/audiolm_pytorch.py#L1392 https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/audiolm_pytorch.py#L113

LWprogramming commented 1 year ago

Right, everything after eos should disappear based on that masking logic, although I'm a bit confused about the relation between this masking and the fine transformer logic you implemented in the FineTransformer change. I think the original issue was that there was an eos token in coarse_token_ids even though it should've been masked out by the code you link. This issue only shows up when we actually try to use these coarse_token_ids in FineTransformer's generate(), which (iiuc) expects the eos to have been masked out correctly.

Oh, in the process of writing this I think I see what you did here: previously the FineTransformerWrapper had the logic for handling eos coarse tokens that weren't properly masked, and you moved it into FineTransformer so it always applies. But if the eos should already be masked out in CoarseTransformer, I'm not sure why we see eos by the time we get to FineTransformer anyway.

hope the trip was nice! :)

lucidrains commented 1 year ago

@LWprogramming that's true! well, it wouldn't hurt to keep it in there for now :smile:

thanks, it was nice!

lucidrains commented 1 year ago

should be resolved, feel free to reopen if any new error pops up!

LWprogramming commented 1 year ago

Hm, it seems to work now when I try a different dataset. I'd originally trained the model on a tiny dataset (intentionally overfitting, to see if it could) with samples trimmed to exactly data_max_length, and that's when the unaligned eos started showing up. It still does, but I can just train on input data that's a bit longer, and that should probably be ok.

edit: hang on, I didn't do the trimming properly. Now I'm not sure what's causing the issue 🙃

lucidrains commented 1 year ago

@LWprogramming yea, i'm confused by your assertions because i remove the [eos] anyways, so i'm not sure what you are fetching https://github.com/LWprogramming/audiolm-pytorch/commit/6b82fc18bbcc839edf83aa13eae55cacb862b034#diff-96a5ee045c1df07f3125d9b4189130620f229785b36cebb86c95b0646f0d744dR1430

LWprogramming commented 1 year ago

In that code (which I updated to this, so it doesn't error out if no eos occurs), I was finding that sampled_coarse_token_ids did in fact contain eos ids for some reason. In the particular case I observed, I saw one at timestep 121, quantizer 2 (with num_coarse_quantizers = 3). In that case is_last_step is True, so it's possible that we get eos there.
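Roughly what the check does (paraphrased, not the exact code linked above; sampled_coarse_token_ids is assumed to be shaped (batch, flattened sequence), and coarse_eos_id stands for whatever the eos id is called in that scope):

```python
eos_hits = (sampled_coarse_token_ids == coarse_eos_id).nonzero(as_tuple=False)

if eos_hits.numel() > 0:                          # don't error out when no eos occurs
    batch_idx, flat_idx = eos_hits[0].tolist()    # first occurrence
    timestep, quantizer = divmod(flat_idx, num_coarse_quantizers)
    print(f"first eos at batch {batch_idx}, timestep {timestep}, quantizer {quantizer}")
```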

lucidrains commented 1 year ago

that is so strange; I just checked the 'mask after eos id function' and could not spot any obvious bugs

https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/audiolm_pytorch.py#L1388

lucidrains commented 1 year ago

I'm glad you pushed on this though, because of the off-by-one eos bug you uncovered

LWprogramming commented 1 year ago

> that is so strange; I just checked the 'mask after eos id function' and could not spot any obvious bugs
>
> https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/audiolm_pytorch.py#L1388

I think the function itself is fine, but the input to the function isn't. If sampled_coarse_token_ids eventually ends up with shape num_coarse_quantizers x num_timesteps, it looks like:

q0,0 q1,0 q2,0 ... q121,0
q0,1 q1,1 q2,1 ... q121,1
q0,2 q1,2 q2,2 ... q121,2

where the notation is q_{timestep},{quantizer number}. But the function input sampled_coarse_token_ids is actually a kind-of-flattened tensor where each element in the batch is a 1-D tensor with the quantizers interleaved per timestep: q0,0 q0,1 q0,2 q1,0 q1,1 q1,2 ... q121,2. In this case the masking code sees the first eos at q121,2 and replaces it (and everything after it) with the masking value.
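Here's a toy version with made-up sizes (a sketch, not the repo's actual mask_out_after_eos_id) showing that the masking starts exactly at that flattened position:

```python
import torch

num_coarse_quantizers = 3
num_timesteps = 122
eos_id = 1024           # illustrative: one past the largest real coarse id
mask_value = -1

# flattened per-batch layout: q0,0 q0,1 q0,2  q1,0 q1,1 q1,2  ... (quantizers interleaved per timestep)
ids = torch.randint(0, 1024, (1, num_timesteps * num_coarse_quantizers))
ids[0, 121 * num_coarse_quantizers + 2] = eos_id        # eos at timestep 121, quantizer 2

def mask_out_after_first_eos(t, eos_id, mask_value):
    # True at the first eos and every position after it, per batch element
    after_eos = (t == eos_id).int().cumsum(dim=-1) > 0
    return t.masked_fill(after_eos, mask_value)

masked = mask_out_after_first_eos(ids, eos_id, mask_value)
first_masked = (masked[0] == mask_value).nonzero()[0].item()
print(divmod(first_masked, num_coarse_quantizers))      # -> (121, 2): masking starts mid-timestep
```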

lucidrains commented 1 year ago

ohh ok, you were checking the eos position before this function was called

this is fortuitous, as the masking would have hidden this unaligned eos error