lucidrains / vector-quantize-pytorch

Vector (and Scalar) Quantization, in Pytorch
MIT License

[BUG] Residual VQ - self.embed.data[ind][mask] = sampled - RuntimeError: shape mismatch: #142

Open · dwromero opened this issue 6 days ago

dwromero commented 6 days ago

Hi all,

I noticed that using ResidualVQ configured as:

ResidualVQ(
    dim=Z_CHANNELS,  # 512
    num_quantizers=NUM_QUANTIZERS,  # 2
    codebook_size=CODEBOOK_SIZE,  # 16 * 1024
    stochastic_sample_codes=True,
    shared_codebook=True,
    commitment_weight=1.0,
    kmeans_init=True,
    threshold_ema_dead_code=2,
    quantize_dropout=True,
    quantize_dropout_cutoff_index=1,
    quantize_dropout_multiple_of=1,
)

leads to the following error:

File "/mnt/workspace/Projects/autoencoder.py", line 261, in forward
    z_tilde, _, commit_loss = self.vq(z)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vector_quantize_pytorch/residual_vq.py", line 183, in forward
    quantized, *rest = layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vector_quantize_pytorch/vector_quantize_pytorch.py", line 919, in forward
    quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/vector_quantize_pytorch/vector_quantize_pytorch.py", line 514, in forward
    self.expire_codes_(x)
  File "/usr/local/lib/python3.10/dist-packages/vector_quantize_pytorch/vector_quantize_pytorch.py", line 443, in expire_codes_
    self.replace(batch_samples, batch_mask = expired_codes)
  File "/usr/local/lib/python3.10/dist-packages/vector_quantize_pytorch/vector_quantize_pytorch.py", line 428, in replace
    self.embed.data[ind][mask] = sampled
RuntimeError: shape mismatch: value tensor of shape [9330, 512] cannot be broadcast to indexing result of shape [9331, 512]

This happens randomly during training (in a multinode setting). Any idea what the cause could be?
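
For reference, the shape mismatch itself is just PyTorch's rule for boolean-mask assignment: the value tensor's first dimension must equal the number of True entries in the mask. Here is a minimal standalone snippet that reproduces the same error with the sizes from the trace (the names are illustrative stand-ins, not the library's internals):

import torch

# stand-ins for the tensors involved in `self.embed.data[ind][mask] = sampled`
codebook = torch.zeros(16 * 1024, 512)              # one codebook: 16k codes of dim 512
expired = torch.zeros(16 * 1024, dtype=torch.bool)
expired[:9331] = True                               # 9331 codes flagged as expired

sampled = torch.randn(9330, 512)                    # but only 9330 replacement vectors

# raises: RuntimeError: shape mismatch: value tensor of shape [9330, 512]
# cannot be broadcast to indexing result of shape [9331, 512]
codebook[expired] = sampled

So at the moment replace runs, the expired-code mask and the batch of sampled replacement vectors disagree in count by one.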

lucidrains commented 5 days ago

@dwromero hey David! want to try setting this to False for now and see if that resolves your issue?

dwromero commented 5 days ago

Hi @lucidrains , thank you for your fast reply. I'll try it out now.

David

dwromero commented 4 days ago

@lucidrains, it works now! I am not sure whether this is a full solution to the problem, though. Please let me know if you think it is, and I will close the issue.

lucidrains commented 4 days ago

@dwromero nice! at least it doesn't block you from your research now!

if you'd like to help me get to the bottom of this, you could turn it back to True in 1.14.40 and share with me the stack trace once it errors again

lucidrains commented 4 days ago

@dwromero the other thing that would be helpful (if you have the time), is to run it with only one quantizer and see if it still errors 🙏
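
i.e. the configuration from the top of the thread with just the quantizer count changed, something like this (I have left quantize dropout off here, since dropping layers is not meaningful with a single quantizer; adjust if you want to keep it):

from vector_quantize_pytorch import ResidualVQ

# same settings as in the original report, but with a single quantizer,
# to check whether the expire/replace error still occurs
vq = ResidualVQ(
    dim=512,                  # Z_CHANNELS
    num_quantizers=1,         # was NUM_QUANTIZERS = 2
    codebook_size=16 * 1024,  # CODEBOOK_SIZE
    stochastic_sample_codes=True,
    shared_codebook=True,
    commitment_weight=1.0,
    kmeans_init=True,
    threshold_ema_dead_code=2,
    quantize_dropout=False,   # disabled for the single-quantizer sanity check (see note above)
)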

lucidrains commented 3 days ago

@dwromero hey David, realized just now the local sampling won't work, as the codes will no longer be synced

could you try again on the latest?

lucidrains commented 2 days ago

@dwromero hey David again

so I think your error may be related to an issue with the quantize dropout in a distributed environment, which would also make the above solution not work. i put in a potential fix, if you are still running experiments
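
to illustrate the suspected interaction (a toy sketch, not library code, and assuming each rank draws its quantize-dropout cutoff independently):

import random

# toy illustration: two ranks independently drawing a dropout cutoff for a
# 2-quantizer stack on the same training step
num_quantizers = 2
rank0, rank1 = random.Random(0), random.Random(1)

for step in range(5):
    cutoff0 = rank0.randrange(1, num_quantizers + 1)   # layers [0, cutoff) run this step
    cutoff1 = rank1.randrange(1, num_quantizers + 1)
    print(f"step {step}: rank0 runs {cutoff0} layer(s), rank1 runs {cutoff1} layer(s)")

# whenever the two ranks disagree, the later quantizer updates its EMA statistics
# (and expires dead codes) on one rank but not the other, so the expired-code mask
# and the number of sampled replacements can drift out of sync across ranks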

lucidrains commented 1 day ago

another way to avoid this issue is to offer a way to delay the expiration of the codes until all the quantizers have been invoked
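
conceptually, something along these lines (pseudocode with hypothetical method names, not the current API):

# collect expired-code masks while looping over the quantizers, and only
# replace dead codes once every quantizer has been invoked for this step
def residual_forward(quantizers, x):
    residual = x
    quantized_out = 0.
    pending = []

    for layer in quantizers:
        quantized, expired_mask = layer.forward_without_expiring(residual)  # hypothetical
        residual = residual - quantized
        quantized_out = quantized_out + quantized
        pending.append((layer, expired_mask))

    # every quantizer has run; now expire and replace dead codes in one pass
    for layer, expired_mask in pending:
        layer.replace_expired(x, expired_mask)                              # hypothetical

    return quantized_out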

dwromero commented 5 hours ago

Hi @lucidrains,

So, just to clarify: I should now be able to run with the same configuration as at the top of this thread, and it should work? Would you like me to check that?