Closed: LWprogramming closed this issue 1 year ago
@LWprogramming oh thanks! i think this should be an easy fix
weird that @jinyuli didn't run into this error
do you have a small script that repros this error?
Here is the script I'm trying on some small dataset, but you can use an artificial one by uncommenting `make_placeholder_dataset()`, switching to `dataset_folder = f"{prefix}/placeholder_dataset"`, and setting all the train steps etc. to be super low so you can get to the error quickly.
And I confirmed that the eos is the problem because the assertion here triggered when my job ran last night
@LWprogramming yes, i do believe there was logic i misplaced into the `FineTransformerWrapper`, which `AudioLM` does not have access to. can you test out 1.1.6 and see if that works out?
Just submitted the job to try (pending, so no results yet) but while we wait, just to check my understanding: this change masks out anything that isn't an actual coarse index, so the transformer doesn't learn anything that relies on those special tokens. However, why does this prevent eos from appearing in the wrong spot (i.e. not aligned with the end of a quantizer step)? Or is the goal just to make that really low probability because during training attention never sees eos?
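(For context on "doesn't learn anything that relies on those special tokens": a common way to achieve this is to mark masked positions with an ignore value in the cross-entropy loss, so no gradient flows from them. A minimal sketch, not the repo's exact code; the -1 ignore value and tensor sizes are illustrative:)

```python
import torch
import torch.nn.functional as F

logits = torch.randn(1, 4, 10)            # (batch, seq, vocab)
targets = torch.tensor([[3, 7, -1, -1]])  # -1 marks masked-out positions

# positions labeled -1 contribute nothing to the loss, so the model
# never gets a training signal from the masked special tokens
loss = F.cross_entropy(logits.transpose(1, 2), targets, ignore_index=-1)

# same value as computing the loss on only the two unmasked positions
manual = F.cross_entropy(logits[0, :2], targets[0, :2])
print(torch.allclose(loss, manual))  # True
```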
i'll have to take a closer look later this weekend since i just got back from my trip and am still readjusting / running errands, but i don't believe the eos token is even kept:
https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/audiolm_pytorch.py#L1392
https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/audiolm_pytorch.py#L113
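(For readers following along, the linked "mask after eos" step amounts to something like this pure-Python sketch; the names and the -1 mask value are illustrative, not the repo's exact tensor code:)

```python
def mask_out_after_eos_id(seq, eos_id, mask_value=-1):
    # Replace the first eos and everything after it with mask_value,
    # so downstream code never consumes tokens past end-of-sequence.
    out, seen_eos = [], False
    for tok in seq:
        if tok == eos_id:
            seen_eos = True
        out.append(mask_value if seen_eos else tok)
    return out

print(mask_out_after_eos_id([5, 7, 9, 2, 3, 4], eos_id=2))
# [5, 7, 9, -1, -1, -1]
```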
Right, everything after eos should disappear based on that masking logic, although I'm a bit confused about the relation between this masking and the fine transformer logic you implemented in the `FineTransformer` change. I think the original issue was that there was an eos token in `coarse_token_ids` even though it should've been masked out by the code you link. This issue only shows up when we actually try to use these `coarse_token_ids` in `FineTransformer`'s `generate()`, which (iiuc) expects the eos to have been masked out correctly.

Oh, in the process of writing this I think I get what you did here? So if previously the `FineTransformerWrapper` had logic to avoid trying to do anything with eos coarse tokens that weren't properly masked, you moved that to `FineTransformer` so it always works. But if the eos should be masked out in `CoarseTransformer` already, I'm not sure why we see eos by the time we get to `FineTransformer` anyway.
hope the trip was nice! :)
@LWprogramming that's true! well, it wouldn't hurt to keep it in there for now :smile:
thanks, it was nice!
should be resolved, feel free to reopen if any new error pops up!
Hm, it seems to work now when I try a different dataset. I'd originally tried to train the model on a tiny dataset (intentionally overfitting to see if it can do that) with samples trimmed to exactly `data_max_length`, and that's when unaligned eos starts showing up. It still does, but I can just try on input data that's a bit larger and that should probably be ok.
edit: hang on, I didn't do the trimming properly. Now I'm not sure what's causing the issue 🙃
@LWprogramming yea, i'm confused by your assertions because i remove the [eos] anyways, so i'm not sure what you are fetching
https://github.com/LWprogramming/audiolm-pytorch/commit/6b82fc18bbcc839edf83aa13eae55cacb862b034#diff-96a5ee045c1df07f3125d9b4189130620f229785b36cebb86c95b0646f0d744dR1430
In that code (which I updated to this, which doesn't error out if no eos occurs), I was finding that `sampled_coarse_token_ids` did in fact contain eos ids for some reason. In the particular case that I observed, I saw it at timestep 121, quantizer 2 (when `num_coarse_quantizers` is 3). In that case, `is_last_step` is `True` and it's possible that we get eos there.
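(To make the aligned-vs-unaligned distinction concrete, here's a hypothetical helper, not from the repo, for checking where the first eos lands in the flattened, generation-order sequence:)

```python
def first_eos_is_aligned(flat_ids, eos_id, num_quantizers):
    # In generation order the sequence goes (t0,q0), (t0,q1), ..., so an
    # eos is "aligned" only if it lands on the last quantizer of a timestep.
    for pos, tok in enumerate(flat_ids):
        if tok == eos_id:
            return pos % num_quantizers == num_quantizers - 1
    return True  # no eos at all: nothing to complain about

EOS = 99
print(first_eos_is_aligned([1, 1, 1, 1, 1, EOS], EOS, num_quantizers=3))  # True
print(first_eos_is_aligned([1, 1, 1, 1, EOS, 1], EOS, num_quantizers=3))  # False

# timestep 121, quantizer 2 with 3 quantizers sits at flat position
# 121 * 3 + 2 = 365, which is aligned -- consistent with is_last_step being True
```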
that is so strange; I just checked the 'mask after eos id function' and could not spot any obvious bugs
https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/audiolm_pytorch.py#L1388
I'm glad you pushed on this though, because of the off-by-one eos bug you uncovered
> that is so strange; I just checked the 'mask after eos id function' and could not spot any obvious bugs
> https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/audiolm_pytorch.py#L1388
I think the function itself is fine, but the input to the function isn't. If `sampled_coarse_token_ids` eventually becomes a tensor of shape `num_coarse_quantizers x num_timesteps`:

q0,0 q1,0 q2,0 ... q121,0
q0,1 q1,1 q2,1 ... q121,1
q0,2 q1,2 q2,2 ... q121,2

where the notation is `q{timestep},{quantizer number}` and each row holds one quantizer across all timesteps, the function input `sampled_coarse_token_ids` is actually a kind-of-flattened tensor where each element in the batch is a 1-D tensor in generation order: `q0,0 q0,1 q0,2 q1,0 q1,1 ... q121,2`. In this case the masking code sees the first eos at `q121,2` and replaces that (and everything after it) with the masking value.
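(A tiny pure-Python sketch of this layout, with the dimensions shrunk to 3 quantizers x 4 timesteps for illustration; the tuples stand in for token ids so the ordering is visible:)

```python
num_quantizers, num_timesteps = 3, 4

# grid[q][t] is the token slot for quantizer q at timestep t,
# i.e. the num_coarse_quantizers x num_timesteps view
grid = [[(t, q) for t in range(num_timesteps)] for q in range(num_quantizers)]

# the "kind-of-flattened" view the masking function actually receives:
# timestep-major with the quantizer index varying fastest (generation order)
flat = [grid[q][t] for t in range(num_timesteps) for q in range(num_quantizers)]

print(flat[:4])  # [(0, 0), (0, 1), (0, 2), (1, 0)]
print(flat[-1])  # (3, 2): the last quantizer of the last timestep,
                 # the analogue of q121,2 in the real run
```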
ohh ok, you were checking the eos position before this function was called
this is fortuitous, as the masking would have hidden this unaligned eos error
I tried training a model back when the repo was at commit 95e0669dde9c177b807fa6f0a52e4d2e685c47fd and successfully got checkpoints, but it crashed when I tried to test the generations. The error message was a hard-to-understand CUDA message:

I suspect the problem is some bug with the coarse transformer's eos handling, because the generation crashes specifically when the fine transformer is just about to get started. I printed state and found that the coarse token ids contained a -1, which I think is the result of applying `mask_out_after_eos_id`. But it turns out that the first index of the -1 was at timestep 121, quantizer 2, i.e. `(0, 121, 2)`, which is not in between "full" quantizer steps; I'd expect the first -1 to appear somewhere like `(batch_index, timestep, 0)`. Seems plausible that this is consistent with a CUDA issue (I'm guessing a -1 where all the indices are expected to be small-ish nonnegative ints could result in some memory bounds violations).

Going to use this issue to track any updates and what I've tried; will be using the script in https://github.com/LWprogramming/audiolm-pytorch-training/blob/main/audiolm_pytorch_demo_laion.py (which I set up to eliminate non-determinism).
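(On the hard-to-understand CUDA message: an out-of-range token id such as the -1 mask value is a classic way to trigger one. A minimal sketch with hypothetical sizes; on CPU the same lookup fails with a readable IndexError, while on a GPU it typically surfaces as an opaque device-side assert:)

```python
import torch
import torch.nn as nn

emb = nn.Embedding(1024, 8)        # codebook-sized lookup table
ids = torch.tensor([5, 100, -1])   # -1 stands in for the post-masking value

try:
    emb(ids)                       # invalid index: raises IndexError on CPU,
except IndexError as e:            # "device-side assert triggered" on CUDA
    print("lookup failed:", e)
```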