Hello @awaelchli, can I solve this issue? I am a beginner in open source but would be really happy to help here. Thanks!
@awaelchli @Bhavay-2001 Well, which solution do you think is best? There's not one that's immediately obvious to me as the best one to implement.
I can see a couple of possible solutions:
- Simply add a warning about logging low precision types when on_epoch=True is enabled
- Use torch reduction operations to mitigate associativity issues
- Add a warning & use the torch reduction operations
- Auto-cast logged values to float32 (or 64?) under the hood
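Option 2 relies on reductions that lose less precision than a single running total. One common technique for that is pairwise summation, and the pure-Python sketch below compares it with one-at-a-time accumulation when every intermediate result is rounded to bfloat16. It assumes bfloat16 can be emulated by rounding the float32 bit pattern with struct (NaN and infinity ignored); to_bf16, sequential_sum_bf16 and pairwise_sum_bf16 are hypothetical helpers, and this is not a claim about exactly how torch.sum is implemented.

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round half to even).

    Emulation only: x is first rounded to float32, then the low 16 bits of the
    float32 bit pattern are rounded away. NaN and infinity are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1  # lowest bit bfloat16 keeps, used for ties-to-even
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def sequential_sum_bf16(values):
    """Add values one at a time, rounding the running total to bfloat16 after each add."""
    acc = 0.0
    for v in values:
        acc = to_bf16(acc + v)
    return acc


def pairwise_sum_bf16(values):
    """Sum each half recursively, rounding every partial sum to bfloat16."""
    if not values:
        return 0.0
    if len(values) == 1:
        return to_bf16(values[0])
    mid = len(values) // 2
    return to_bf16(pairwise_sum_bf16(values[:mid]) + pairwise_sum_bf16(values[mid:]))


ones = [1.0] * 1024
print(sequential_sum_bf16(ones))  # 256.0: stuck once 256 + 1 rounds back to 256
print(pairwise_sum_bf16(ones))    # 1024.0: every partial sum is a power of two
```

Pairwise summation keeps the partial sums at similar magnitudes, so small terms are not swamped by one large running total. A wider accumulator (option 4) attacks the same problem more directly.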
Hi @MF-FOOM, I am also a beginner in open source. I am not sure about this and would like to discuss it with @awaelchli.
Here is the relevant code where the accumulation happens.
I vote for converting floating-point scalars to full precision before accumulation, and against storing all values, because the user doesn't care about the internal representation, only about the final reduced value. I vote against warnings because users generally ignore them.
So my suggestion is to call .float() before the summation.
As a workaround in the current version, the user can do the same before passing the value to the log method.
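To make the suggestion concrete, here is a minimal pure-Python sketch of an epoch-level mean that sums logged values and divides at the end, once with a bfloat16 running total and once with the value upcast to float32 before accumulation. RunningMean, to_bf16 and to_f32 are hypothetical helpers that emulate rounding with struct, and the constant 0.1 is an arbitrary illustrative value, not the one used in the reproduction from the bug report. In a LightningModule, the user-side workaround amounts to passing loss.float() rather than the low-precision tensor to self.log.

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (emulated via the float32 bit pattern)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000  # round half to even
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def to_f32(x: float) -> float:
    """Round x to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


class RunningMean:
    """Hypothetical epoch-level mean: add each logged value, divide at the end.

    round_fn emulates the dtype the running total is stored in.
    """

    def __init__(self, round_fn):
        self.round_fn = round_fn
        self.total = 0.0
        self.count = 0

    def update(self, value: float) -> None:
        self.total = self.round_fn(self.total + value)
        self.count += 1

    def compute(self) -> float:
        return self.round_fn(self.total / self.count)


loss = to_bf16(0.1)               # constant bf16 "loss" logged on every step
bf16_mean = RunningMean(to_bf16)  # running total kept in bfloat16
f32_mean = RunningMean(to_f32)    # value upcast (like .float()) before accumulating

for _ in range(1000):
    bf16_mean.update(loss)
    f32_mean.update(loss)

print(bf16_mean.total, bf16_mean.compute())  # total stalls at 32.0, mean is about 0.032
print(f32_mean.compute())                    # about 0.1001, the logged value
```

Because the bfloat16 total stops growing once the logged value is less than half the spacing between representable numbers, its mean gets worse the more steps an epoch has.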
sgtm! @Bhavay-2001 do you still want to give it a try?
Hi @MF-FOOM, maybe you can go ahead with this. I'm having a little trouble understanding it.
@awaelchli, can you suggest some open-source contributions for beginners?
@Bhavay-2001 Thanks for your interest. I suggest that you join the "want-to-contribute" Discord channel and we can find something that fits you.
Bug description
When training with a low-precision type (fp16, bf16, etc.), logging loss and other values via self.log(..., on_epoch=True) yields very inaccurate reductions (whether mean, sum, etc.). This is because instead of using the torch functions for these operations (torch.sum, torch.mean, etc.), Lightning currently does the reduction manually, simply adding up new values as they're logged (and then dividing at the end in the case of mean).
The issue with this is that, with low-precision types, floating-point non-associativity becomes a really big deal: the accumulator can get stuck if the logged values aren't large enough to push it to the next representable number (e.g. 256 + 1 == 256 in bfloat16). torch.sum, torch.mean, etc. all help mitigate this under the hood (e.g. such that torch.sum(torch.tensor([256, 1, 1], dtype=torch.bfloat16)) == 258 instead of getting stuck at 256), but since Lightning does not use these functions, precision suffers greatly.
However, even if we were to refactor the accumulation logic to use these torch operations, I still worry that doing reductions in these small types is simply not precise enough, and that it is an easy trap for users to fall into without noticing. I personally have been casting my loss values to float32 to remedy this.
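The 256 + 1 example can be checked without a GPU. This pure-Python sketch emulates bfloat16 rounding with struct (to_bf16, manual_sum_bf16 and wide_sum_bf16 are hypothetical helpers) and contrasts rounding the running total after every addition, as the manual accumulation does, with keeping a wider accumulator and rounding once at the end. The wide accumulator reproduces the 258 reported above for torch.sum; whether torch.sum uses exactly this strategy on every device is not claimed here.

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (emulated via the float32 bit pattern)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000  # round half to even
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def manual_sum_bf16(values):
    """Running total rounded to bfloat16 after every add (per-step rounding)."""
    acc = 0.0
    for v in values:
        acc = to_bf16(acc + v)
    return acc


def wide_sum_bf16(values):
    """Accumulate in double precision, round to bfloat16 only once at the end."""
    return to_bf16(sum(values))


# 257 lies halfway between the bfloat16 neighbours 256 and 258; the tie goes to 256.
print(to_bf16(256.0 + 1.0))                # 256.0
print(manual_sum_bf16([256.0, 1.0, 1.0]))  # 256.0
print(wide_sum_bf16([256.0, 1.0, 1.0]))    # 258.0
```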
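I can see a couple of possible solutions:
- Simply add a warning about logging low precision types when on_epoch=True is enabled
- Use torch reduction operations to mitigate associativity issues
- Add a warning & use the torch reduction operations
- Auto-cast logged values to float32 (or 64?) under the hood
What version are you seeing the problem on?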
How to reproduce the bug
I've demonstrated how bad these precision issues can be here:
Observe that I'm logging a fixed constant value on each validation step, yet the reduced value comes out to 0.51171875.
Error messages and logs
Current environment
