lucidrains / vector-quantize-pytorch

Vector (and Scalar) Quantization, in Pytorch
MIT License

VectorQuantize not JIT-safe #8

Open JackMcCoy opened 2 years ago

JackMcCoy commented 2 years ago

The current code drives Python control flow from the values of tensors (such as the "initted" buffer), so it will not work correctly when compiled with jit.trace.

lucidrains commented 2 years ago

@JackMcCoy Darn, I'm not sure how to approach that. I guess if we aren't doing a k-means init of the codebook on the first pass, I could remove that piece of logic.

JackMcCoy commented 2 years ago


Yeah, that was my thought. I set kmeans_init and use_cosine_sim to False (I think both are good features; I just didn't see a straightforward way to use them here) and thought running jit.trace with checks might work. It ran, but the model didn't train properly (something else may be going on, so I'm not saying this is conclusive). I was going to copy the code into my repo and delete the "initted" buffer and its check, but then I tried running without JIT-compiling that network, found it wasn't any slower, and didn't test further.

Other libraries have tensor-based control flow functions, but I don't see anything like that for PyTorch.

JackMcCoy commented 2 years ago

Another option might be to move the control-flow handling into separate methods tagged with @torch.jit.ignore.
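A sketch of that idea, assuming torch.jit.script rather than tracing: as far as I know, @torch.jit.ignore only affects the script compiler, which leaves the method as a call back into Python. torch.jit.trace simply runs the Python code and records the ops, so the decorator does not change a traced graph. Also, per the PyTorch docs, models with ignored methods cannot be exported with torch.jit.save; @torch.jit.unused is the export-friendly variant. IgnoredInitCodebook is a hypothetical toy module:

```python
import torch
from torch import nn


class IgnoredInitCodebook(nn.Module):
    """Hypothetical toy codebook with its lazy-init logic moved into an ignored method."""

    def __init__(self, dim: int = 4, size: int = 8):
        super().__init__()
        self.register_buffer("initted", torch.tensor(False))
        self.register_buffer("embed", torch.zeros(size, dim))

    @torch.jit.ignore
    def maybe_init(self, x: torch.Tensor) -> None:
        # Not compiled: torch.jit.script leaves this as a call back into Python.
        if not bool(self.initted):
            self.embed.copy_(x[: self.embed.shape[0]])
            self.initted.fill_(True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.maybe_init(x)
        return x @ self.embed.t()


torch.manual_seed(0)
first, second = torch.randn(8, 4), torch.randn(8, 4)
scripted = torch.jit.script(IgnoredInitCodebook())
scripted(first)
scripted(second)
# The init check runs in Python on each call, so only the first batch is copied in.
print(torch.equal(scripted.embed, first))
```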

lucidrains commented 2 years ago

@JackMcCoy OK, let's try torch.jit.ignore! See if the latest version helps.

JackMcCoy commented 2 years ago

Unfortunately, no. Looking at the code again, ema_inplace() is another obstacle. The in-place update is obviously an important performance choice, and keeping it while also providing a JIT-safe path could end up looking fairly messy. Maybe a separate version is the best option, if it's worth doing.

Any thoughts? From what I've read, the JIT might be able to fuse the out-of-place operations. I may poke around and will comment on anything I find.
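For comparison, an exponential moving average update can be written either in place or out of place. The sketch below uses two hypothetical helpers (ema_update_inplace and ema_out_of_place; the library's own ema_inplace may differ) to show they compute the same value. The out-of-place version is a pure function that torch.jit.script compiles directly, at the cost of allocating a new tensor each step.

```python
import torch


def ema_update_inplace(moving_avg: torch.Tensor, new: torch.Tensor, decay: float) -> None:
    # In-place: overwrites moving_avg's storage, no new allocation.
    moving_avg.mul_(decay).add_(new, alpha=1 - decay)


@torch.jit.script
def ema_out_of_place(moving_avg: torch.Tensor, new: torch.Tensor, decay: float) -> torch.Tensor:
    # Pure function: returns a fresh tensor and leaves its inputs untouched.
    return moving_avg * decay + new * (1.0 - decay)


torch.manual_seed(0)
avg, new = torch.randn(8, 4), torch.randn(8, 4)

expected = ema_out_of_place(avg, new, 0.8)  # computed before avg is mutated
ema_update_inplace(avg, new, 0.8)
print(torch.allclose(avg, expected))
```

Inside a module, the out-of-place result would still have to be written back to the buffer (for example with copy_), so the saving from going out of place is mainly that the arithmetic itself becomes a pure, scriptable function.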

fractaldna22 commented 2 years ago

Neither is CLIP JIT-safe. Or rather, JIT is not safe, so stop using it? :P

JackMcCoy commented 2 years ago


You know, there are ways to use a quantized codebook besides with CLIP!

danieltudosiu commented 2 years ago

In our internal VQ-VAE implementation we had the same issue with DDP syncing of the codebooks, and we resorted to the following solution:

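    # Excerpt from a VQ-VAE nn.Module; assumes `import torch` and `from typing import Tuple`.
    @torch.jit.unused
    def ddp_sync(self, encodings_sum: torch.Tensor, dw: torch.Tensor) -> None:
        # Skipped by the TorchScript compiler; calling it from a scripted model raises.
        if self._ddp_sync and torch.distributed.is_initialized():
            # Sum the codebook statistics across all DDP workers.
            torch.distributed.all_reduce(tensor=encodings_sum, op=torch.distributed.ReduceOp.SUM)
            torch.distributed.all_reduce(tensor=dw, op=torch.distributed.ReduceOp.SUM)

    def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            ....
                # Plain bool flag: with it off, the unused method is never reached.
                if self._ddp_sync:
                    self.ddp_sync(encodings_sum, dw)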

Basically, anything that is not JIT-compatible should be possible to switch off and skip with a plain if statement.
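A runnable sketch of the same pattern, with a hypothetical CodeCounter module standing in for the VQ-VAE and a single all_reduce call in place of the real sync. With the flag off, the model scripts, saves and reloads; with the flag on, the scripted model raises once it reaches the @torch.jit.unused method.

```python
import io

import torch
from torch import nn


class CodeCounter(nn.Module):
    """Hypothetical toy module: counts code usage, with an optional non-scriptable sync step."""

    def __init__(self, num_codes: int = 8, ddp_sync: bool = False):
        super().__init__()
        self.num_codes = num_codes
        self._ddp_sync = ddp_sync
        self.register_buffer("counts", torch.zeros(num_codes))

    @torch.jit.unused
    def ddp_sync(self, counts: torch.Tensor) -> None:
        # Never compiled; in TorchScript this body is replaced by a raise.
        torch.distributed.all_reduce(counts)

    def forward(self, codes: torch.Tensor) -> torch.Tensor:
        counts = torch.bincount(codes, minlength=self.num_codes).to(self.counts.dtype)
        if self._ddp_sync:  # plain bool attribute acts as the off switch
            self.ddp_sync(counts)
        self.counts.add_(counts)
        return counts


codes = torch.tensor([0, 1, 1, 7])

# Flag off: the module scripts, saves and loads without ever touching ddp_sync.
scripted = torch.jit.script(CodeCounter(ddp_sync=False))
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)
print(loaded(codes).tolist())  # [1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]

# Flag on: the scripted model raises as soon as the unused method is reached.
scripted_on = torch.jit.script(CodeCounter(ddp_sync=True))
try:
    scripted_on(codes)
except Exception as err:
    print("raised:", type(err).__name__)
```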