registered buffers' dtype is overridden after __init__ #18982

Open MF-FOOM opened 10 months ago

MF-FOOM commented 10 months ago

Bug description

If I register a float64 tensor to a buffer in the __init__ function of a LightningModule like so:

self.register_buffer("testing_variable", torch.tensor([1,2,3], dtype=torch.float64))

After setup, the buffer is cast to the dtype of the Trainer's precision setting, even though a different dtype (float64 here) was intended.
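A possible explanation, assuming Lightning's true-precision mode ends up calling nn.Module.to(dtype) on the model (the exact line is referenced later in the thread but not reproduced here): in plain PyTorch, Module.to(dtype) casts every floating-point parameter and buffer, while integer buffers keep their dtype. The Demo module and the counter buffer below are illustrative only.

```python
import torch
from torch import nn


class Demo(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(3, 1)
        # Floating-point buffer that is meant to stay in float64
        self.register_buffer("testing_variable", torch.tensor([1, 2, 3], dtype=torch.float64))
        # Integer buffer, for comparison
        self.register_buffer("counter", torch.tensor([0], dtype=torch.int64))


# Module.to(dtype) casts floating-point parameters AND floating-point buffers.
model = Demo().to(torch.bfloat16)
print(model.linear.weight.dtype)     # torch.bfloat16
print(model.testing_variable.dtype)  # torch.bfloat16 (the float64 intent is lost)
print(model.counter.dtype)           # torch.int64 (integer buffers keep their dtype)
```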

What version are you seeing the problem on?


How to reproduce the bug

Error messages and logs



cc @borda @carmocca @justusschock @awaelchli

carmocca commented 10 months ago

This is caused by this line:

We could:

I don't see a perfect solution here; there will always be edge cases. What's your opinion, @awaelchli?

awaelchli commented 8 months ago

In the provided code, the user chose precision="bf16-true", which is the explicit way of saying "I want everything in bfloat16", and this is what Lightning does. Excluding buffers by default would be an arbitrary choice for the framework to make. Besides, if buffers were excluded, the user would likely have to change the code in forward that uses them.
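To illustrate the last point: if a buffer stayed in float64 while the activations were in bfloat16, PyTorch's type promotion would silently upcast the result, so forward would need explicit casts to stay in bfloat16. A minimal sketch with illustrative tensors:

```python
import torch

activations = torch.ones(3, dtype=torch.bfloat16)            # what a bf16-true model produces
buffer = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)  # a buffer left out of the cast

# Mixed-dtype arithmetic promotes to the wider floating-point type.
mixed = activations * buffer
print(mixed.dtype)  # torch.float64

# Keeping the computation in bfloat16 requires an explicit cast inside forward.
kept = activations * buffer.to(activations.dtype)
print(kept.dtype)  # torch.bfloat16
```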

In our documentation, we could make this clear by mentioning the buffers in this sentence:

If there is a strong desire to exclude buffers, one could add a flag to the HalfPrecision plugin:

trainer = Trainer(precision=HalfPrecision(buffers=False))
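HalfPrecision(buffers=False) is only a proposal in this thread, not an existing argument. As a rough sketch of the behavior it would imply, the hypothetical helper cast_parameters_only below converts floating-point parameters to the target dtype and leaves every buffer untouched. It is plain PyTorch and makes no claim about how Lightning's precision plugins are implemented.

```python
import torch
from torch import nn


def cast_parameters_only(module: nn.Module, dtype: torch.dtype) -> nn.Module:
    """Hypothetical helper: cast floating-point parameters to `dtype`, skip all buffers."""
    with torch.no_grad():
        for param in module.parameters():
            if param.is_floating_point():
                # Swap the parameter's underlying tensor; the Parameter object itself is kept.
                param.data = param.data.to(dtype)
    return module


class Demo(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(3, 1)
        self.register_buffer("testing_variable", torch.tensor([1, 2, 3], dtype=torch.float64))


model = cast_parameters_only(Demo(), torch.bfloat16)
print(model.linear.weight.dtype)     # torch.bfloat16
print(model.testing_variable.dtype)  # torch.float64
```

With this behavior, any forward that combines such a buffer with bfloat16 activations would need its own casts, which is the trade-off awaelchli describes.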

carmocca commented 8 months ago

Adding a flag makes sense to me. Similarly, one might want to control the output dtype, which could also be a flag.
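For the output-dtype idea, one hypothetical shape is a wrapper that runs the wrapped module and casts a floating-point output to a chosen dtype. The class name OutputDtypeWrapper and its output_dtype argument are invented for illustration; the thread does not specify how such a flag would work.

```python
import torch
from torch import nn


class OutputDtypeWrapper(nn.Module):
    """Hypothetical wrapper: cast the tensor output of `inner` to `output_dtype`."""

    def __init__(self, inner: nn.Module, output_dtype: torch.dtype) -> None:
        super().__init__()
        self.inner = inner
        self.output_dtype = output_dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.inner(x)
        # Only floating-point outputs are cast; other dtypes pass through unchanged.
        return out.to(self.output_dtype) if out.is_floating_point() else out


inner = nn.Linear(3, 1).to(torch.bfloat16)  # model computes in bfloat16
wrapped = OutputDtypeWrapper(inner, torch.float32)
y = wrapped(torch.ones(2, 3, dtype=torch.bfloat16))
print(y.dtype)  # torch.float32
```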