lucasrelic99 opened 2 months ago
Thanks for the report.
I could add a version check for torch<2.0:
```python
import torch
from packaging.version import Version


class CheckerboardLatentCodec(LatentCodec):
    def _y_ctx_zero(self, y: Tensor) -> Tensor:
        if Version(torch.__version__) < Version("2.0.0"):
            # Meta-device tensors are unreliable before 2.0, so compute
            # the real context and mask it all out.
            return self._mask(self.context_prediction(y).detach(), "all")
        # Infer the output shape on the meta device without running the conv.
        return y.new_zeros(self.context_prediction(y.to("meta")).shape)
```
...but perhaps simpler is just to revert:
```python
class CheckerboardLatentCodec(LatentCodec):
    def _y_ctx_zero(self, y: Tensor) -> Tensor:
        # Avoid meta-device tensors entirely: compute the full context
        # and mask everything out.
        return self._mask(self.context_prediction(y).detach(), "all")
```
To be fair, I don't actually know the earliest torch version that supports meta-device tensors; I couldn't find any solid information.
That said, I think the simpler fix is probably good enough. On my machine (14900K CPU, 3090 GPU), with an unreasonably large context size of (16, 192, 512, 512), that line takes 0.06 ms to execute on GPU. It does take about 4 seconds on CPU, but with a more reasonable context size of (16, 192, 32, 32) it takes roughly 80 ms on CPU.
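For reference, timings in that ballpark can be gathered with a quick harness. This is only a sketch under assumptions: a plain 5x5 convolution stands in for the masked context model, and the batch is scaled down to 1 to keep it fast (scale the tensor up to match the sizes quoted above):

```python
import time

import torch
import torch.nn as nn

# Hypothetical stand-in for the checkerboard context model: a plain 5x5 conv
# (the real module is a masked convolution, but the cost is comparable).
context_prediction = nn.Conv2d(192, 384, kernel_size=5, padding=2)

# Batch of 1 rather than 16 to keep the sketch quick.
y = torch.zeros(1, 192, 32, 32)

start = time.perf_counter()
y_ctx = context_prediction(y).detach()
y_ctx = torch.zeros_like(y_ctx)  # stand-in for self._mask(..., "all")
elapsed_ms = (time.perf_counter() - start) * 1000
print(f"{elapsed_ms:.2f} ms on CPU")
```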
Bug
Using the CheckerboardLatentCodec with a non-identity context_prediction module results in a runtime error during the forward pass. I believe this should only occur when using a torch version less than 2.0.
To Reproduce
Steps to reproduce the behavior:

1. Construct a `CheckerboardLatentCodec` with a non-identity `context_prediction` module.
2. Call the `.forward()` method of the latent codec.

Minimal working example:
This code results in the error:
Expected behavior
The code should not throw an error.
Environment
Output from `python3 -m torch.utils.collect_env`:

Additional context
I am quite certain this is because older PyTorch versions do not support operations on tensors on the "meta" device. I think support was introduced with PyTorch 2.0, but I couldn't find anything definitive from a quick search.
I traced this back to commit eddb1bc, which uses meta device tensors to compute the expected size of the checkerboard context tensor. Replacing these lines with the previous version resolved the issue for me.
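For illustration, the meta-device trick in that commit computes an output shape without allocating memory or running the convolution. A minimal sketch, with a plain hypothetical Conv2d standing in for the real context model:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the checkerboard context model.
conv = nn.Conv2d(192, 384, kernel_size=5, padding=2)

y = torch.zeros(2, 192, 32, 32)

# Moving both module and input to the meta device allocates no storage and
# runs no arithmetic, but shape propagation still happens.
meta_out = conv.to("meta")(y.to("meta"))
print(meta_out.shape)   # torch.Size([2, 384, 32, 32])
print(meta_out.device)  # meta

# Allocate real zeros of the inferred shape, as the committed code does.
y_ctx_zero = y.new_zeros(meta_out.shape)
```

On torch versions without meta-device support for these ops, the forward call above is exactly where the runtime error surfaces.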