lucidrains / vector-quantize-pytorch

Vector (and Scalar) Quantization, in Pytorch
MIT License

Wrong distributed operation in LFQ #158

Closed: Jason3900 closed this issue 2 months ago

Jason3900 commented 2 months ago

https://github.com/lucidrains/vector-quantize-pytorch/blob/e99128dd6d780cc7b97cc5c4f37a99b05a834c57/vector_quantize_pytorch/lookup_free_quantization.py#L13

The maybe_distributed_mean function may produce incorrect results: with the c10d communication backend, torch.distributed.all_reduce does not flow the gradient back to each GPU (related discussion).

The proposed fix is to use from torch.distributed import nn as dist_nn and call the dist_nn.all_reduce operator, which retains the gradient. Otherwise, the following warning is shown and the result is incorrect (the codebook entropy loss becomes abnormal in my experiments with LFQ):

```
/root/miniconda3/envs/video_tok/lib/python3.10/site-packages/torch/autograd/graph.py:744: UserWarning: c10d::allreduce_: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at ../torch/csrc/autograd/autograd_not_implemented_fallback.cpp:63.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
```
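
For reference, a minimal sketch (not the repository's exact code) of computing the distributed mean with the autograd-aware collective from torch.distributed.nn, so that gradients from the entropy loss flow back to every rank:

```python
import torch
import torch.distributed as dist
from torch.distributed import nn as dist_nn


def maybe_distributed_mean(t: torch.Tensor) -> torch.Tensor:
    # Single-process case: the mean over ranks is just the local value.
    if not dist.is_initialized() or dist.get_world_size() == 1:
        return t

    # dist_nn.all_reduce has an autograd kernel registered, so the summed
    # tensor (and the division below) participates in the backward pass on
    # every rank. Plain dist.all_reduce would give the same forward value
    # but would not propagate gradients, producing the warning above.
    t = dist_nn.all_reduce(t)
    return t / dist.get_world_size()
```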
lucidrains commented 2 months ago

@Jason3900 yes you are right, thank you

Jason3900 commented 2 months ago

You're welcome. Thanks for your excellent project!

lucidrains commented 2 months ago

@Jason3900 no problem, go train something amazing with it

hummat commented 2 months ago

Is this also relevant for the all_reduce calls in vector_quantize_pytorch.py, such as this one?

https://github.com/lucidrains/vector-quantize-pytorch/blob/59304110b656e2865e15e5918a2c9e82989decb8/vector_quantize_pytorch/vector_quantize_pytorch.py#L316

lucidrains commented 2 months ago

@hummat i don't think so, because they are only used during the EMA or kmeans updates, neither of which needs gradients
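
For context, a minimal sketch (with illustrative names, not the library's actual code) of why the non-differentiable collective is sufficient in those places: the reduced tensors are running-statistics buffers updated outside of autograd, so no gradient ever needs to flow through the all_reduce.

```python
import torch
import torch.distributed as dist


@torch.no_grad()  # EMA / kmeans statistics live outside the computation graph
def ema_update_cluster_sizes(local_counts: torch.Tensor,
                             ema_counts: torch.Tensor,
                             decay: float = 0.99) -> None:
    # Sum the per-rank counts so every GPU sees the same global statistics.
    # The in-place, non-differentiable dist.all_reduce is enough here,
    # since nothing downstream requires gradients through this buffer.
    if dist.is_initialized() and dist.get_world_size() > 1:
        dist.all_reduce(local_counts)

    # Standard exponential moving average applied to a buffer.
    ema_counts.mul_(decay).add_(local_counts, alpha=1 - decay)
```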