xukai92 opened this issue 4 years ago
I could be missing the point, but isn't this why logitbinarycrossentropy
exists – it's a lot more numerically stable than the unfused version?
Of course if the behaviour is better with Tracker, that suggests we can improve something in Zygote; would be good to see what the divergence is there.
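For concreteness, here is a rough sketch (not Flux's exact source) of the unfused and fused forms for a zero target, showing why the fused one stays stable:

# Unfused: binarycrossentropy(sigmoid(x), 0). For large x, sigmoid(x) rounds to 1,
# so log(1 - 1) blows up and the gradient is lost.
unfused_bce0(x) = -log(1 - 1 / (1 + exp(-x)))

# Fused: logitbinarycrossentropy(x, 0) works on the logit directly via a numerically
# stable log-sigmoid, so it never forms 1 - sigmoid(x).
stable_logsigmoid(x) = min(x, zero(x)) - log1p(exp(-abs(x)))
fused_bce0(x) = x - stable_logsigmoid(x)

unfused_bce0(50.0)  # Inf
fused_bce0(50.0)    # ≈ 50.0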
Sorry, I'm not being clear here. I should have said that the gradients for these two cases are very different for Zygote. For the MWE below, I'd expect the gradients for both to be close, and near zero.
using Flux

gradient(x -> Flux.binarycrossentropy(sigmoid(x), 0), 50)
# (0.0,)
gradient(x -> Flux.logitbinarycrossentropy(x, 0), 50)
# (1.0,)
More information about where it starts to behave weirdly can be seen from the sweep below. Hope it's clear!
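(The sweep was attached as a plot in the original issue; a rough reconstruction of the idea, not the original script, could look like this:)

# Compare the two gradients over a range of logits; they should agree, but the
# unfused version drops to zero once sigmoid saturates in Float32.
using Flux
for x in 1f0:1f0:30f0
    g_bce   = gradient(z -> Flux.binarycrossentropy(sigmoid(z), 0), x)[1]
    g_logit = gradient(z -> Flux.logitbinarycrossentropy(z, 0), x)[1]
    println((x, g_bce, g_logit))
end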
That does seem suspicious. Do you know what Tracker or ReverseDiff does here?
This is how Tracker works. A bit unexpected though, because as I said my training was stable with Tracker... Not sure what exactly happens here then.
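(For anyone who wants to reproduce the comparison, a check against Tracker might look like this; a sketch, assuming Tracker.jl is installed as a separate package alongside Flux:)

using Flux, Tracker

# Same MWE as above, but differentiated with Tracker instead of Zygote.
Tracker.gradient(x -> Flux.binarycrossentropy(sigmoid(x), 0), 50.0)
Tracker.gradient(x -> Flux.logitbinarycrossentropy(x, 0), 50.0)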
This is possibly exactly what I'm observing in #876.
When x is large, exp(-x) << 1 and 1 + exp(-x) becomes 1 in floating point. The threshold is 16.63553f0 for Float32 and 36.7368005696771 for Float64. If x exceeds this number, σ(x) = 1 and the gradient σ'(x) becomes zero, because it is defined as σ(x)(1 - σ(x)) in Zygote (https://github.com/FluxML/Zygote.jl/blob/84bf62ea18330389c64d0d918c91d7b897e1a5d8/src/lib/nnlib.jl#L8-L11). This is why the behaviour of binarycrossentropy(sigmoid(x), 0) looks weird.
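Those thresholds are just the point where exp(-x) drops below eps(T)/2, so that 1 + exp(-x) rounds to 1; a quick check (my own sketch) reproduces them:

# Solve 1 + exp(-x) == 1 in floating point: x > -log(eps(T)/2).
for T in (Float32, Float64)
    println(T, " threshold ≈ ", -log(eps(T) / 2))
end
# ≈ 16.635532 for Float32 and ≈ 36.7368005696771 for Float64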
using Flux

for x in (16.63553f0, 36.7368005696771)
    around_x = nextfloat.(x, -3:3)
    display(collect(zip(around_x, σ.(around_x))))
end
7-element Array{Tuple{Float32,Float32},1}:
(16.635525, 0.9999999)
(16.635527, 0.9999999)
(16.635529, 0.9999999)
(16.63553, 0.9999999)
(16.635532, 1.0)
(16.635534, 1.0)
(16.635536, 1.0)
7-element Array{Tuple{Float64,Float64},1}:
(36.73680056967708, 0.9999999999999998)
(36.73680056967709, 0.9999999999999998)
(36.736800569677094, 0.9999999999999998)
(36.7368005696771, 0.9999999999999998)
(36.73680056967711, 1.0)
(36.736800569677115, 1.0)
(36.73680056967712, 1.0)
for x in (16.63553f0, 36.7368005696771)
    around_x = nextfloat.(x, -3:3)
    display(collect(zip(around_x, σ'.(around_x))))
end
7-element Array{Tuple{Float32,Float32},1}:
(16.635525, 1.19209275e-7)
(16.635527, 1.19209275e-7)
(16.635529, 1.19209275e-7)
(16.63553, 1.19209275e-7)
(16.635532, 0.0)
(16.635534, 0.0)
(16.635536, 0.0)
7-element Array{Tuple{Float64,Float64},1}:
(36.73680056967708, 2.2204460492503126e-16)
(36.73680056967709, 2.2204460492503126e-16)
(36.736800569677094, 2.2204460492503126e-16)
(36.7368005696771, 2.2204460492503126e-16)
(36.73680056967711, 0.0)
(36.736800569677115, 0.0)
(36.73680056967712, 0.0)
If we assign a small positive gradient, we can avoid zero gradients in binarycrossentropy(sigmoid(x), 0).
mysigmoid(x) = one(x) / (one(x) + exp(-x))

Flux.@adjoint function mysigmoid(x)
    y = mysigmoid(x)
    # If the forward value has rounded to 1, back off to the largest float below 1
    # so the pullback stays positive instead of collapsing to zero.
    z = ifelse(y == one(y), prevfloat(one(x)), y)
    w = z * (1 - z)
    return y, Δ -> (Δ * w,)
end
https://nbviewer.jupyter.org/gist/matsueushi/666c7e7e62c093d998a839017869a519
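A quick way to check the patched adjoint (a sketch; assumes Flux is loaded and the definitions above are in scope):

using Flux

# With the stock sigmoid the gradient vanishes past the threshold; with the
# clamped adjoint it stays nonzero (though, as noted further down, not equal to 1).
gradient(x -> Flux.binarycrossentropy(sigmoid(x), 0), 50.0)    # (0.0,)
gradient(x -> Flux.binarycrossentropy(mysigmoid(x), 0), 50.0)  # nonzero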
I think what @matsueushi explained in https://github.com/FluxML/Flux.jl/issues/914#issuecomment-577509073 makes sense. Any idea how we should approach fixing this, @MikeInnes?
Once x exceeds the threshold, sigmoid(x) remains constant. This means binarycrossentropy(sigmoid(x), 0) is also constant and its numerical derivative becomes zero, as we have seen in Zygote's behaviour. In my opinion, we cannot rely on the value (or gradient) of binarycrossentropy(sigmoid(x), 0) for large x in floating-point arithmetic, because even if we make some changes, sigmoid becomes constant eventually.
I would recommend using logitBCE instead of sigmoid + BCE, as in https://github.com/tensorflow/tensorflow/issues/2462.
We should update all models in Flux's docs and in the model-zoo to use a linear layer as the last layer and logitcrossentropy as the loss.
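For illustration, the recommended pattern would look something like this (a hypothetical model, not taken from the docs; the exact loss call may differ between Flux versions):

using Flux
using Statistics: mean

# Keep the last layer linear and apply the fused loss to the raw logits.
model = Chain(Dense(10, 32, relu), Dense(32, 1))
loss(x, y) = mean(Flux.logitbinarycrossentropy.(model(x), y))

# Apply sigmoid only when probabilities are actually needed, e.g. at inference time:
predict(x) = sigmoid.(model(x))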
For reference, here's PyTorch with Float32:
import torch
import torch.nn.functional as F

zero = torch.tensor(0, dtype=torch.float)

def get_grads(i):
    x = torch.tensor(i, dtype=torch.float, requires_grad=True)
    bce = F.binary_cross_entropy(torch.sigmoid(x), zero)
    lbce = F.binary_cross_entropy_with_logits(x, zero)
    return torch.autograd.grad(bce, x), torch.autograd.grad(lbce, x)

for i in range(50):
    print(get_grads(i + 1))
###
((tensor(0.7311),), (tensor(0.7311),))
((tensor(0.8808),), (tensor(0.8808),))
((tensor(0.9526),), (tensor(0.9526),))
((tensor(0.9820),), (tensor(0.9820),))
((tensor(0.9933),), (tensor(0.9933),))
((tensor(0.9975),), (tensor(0.9975),))
((tensor(0.9991),), (tensor(0.9991),))
((tensor(0.9997),), (tensor(0.9997),))
((tensor(0.9999),), (tensor(0.9999),))
((tensor(1.0000),), (tensor(1.0000),))
((tensor(1.0000),), (tensor(1.0000),))
((tensor(1.0000),), (tensor(1.0000),))
((tensor(1.0000),), (tensor(1.0000),))
((tensor(1.0000),), (tensor(1.0000),))
((tensor(1.0000),), (tensor(1.0000),))
((tensor(1.0000),), (tensor(1.0000),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
Is there anything actionable here, or can we close?
Possibly, if we want to implement @matsueushi's suggestion around sigmoid. Otherwise I don't think there's anything actionable.
Possibly, if we want to implement @matsueushi's suggestion around sigmoid
That asymptotic 0.5 is still wrong, and possibly a harder-to-detect problem. Also, performance and CUDA compatibility would need to be assessed. Maybe it's better to leave things as they are, @matsueushi?
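(If I read the notebook right, the 0.5 comes from the clamp interacting with binarycrossentropy's ϵ guard: once sigmoid saturates, the clamped σ' is about eps()/2 while ∂BCE/∂ŷ is about 1/eps(), so their product tends to roughly 0.5 instead of the correct limit of 1. A back-of-the-envelope check, my own sketch assuming ϵ = eps(ŷ):)

# Clamped σ'(x) for large x: z*(1-z) with z = prevfloat(1.0), i.e. ≈ eps()/2.
# BCE gradient wrt ŷ at ŷ = 1, y = 0, with the ϵ guard: 1/(1 - ŷ + ϵ) = 1/eps().
z = prevfloat(1.0)
(z * (1 - z)) * (1 / eps(1.0))  # ≈ 0.5, not the correct limit of 1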
When the logit is large, the two functions can behave quite differently.
I met this issue when training a GAN using Zygote (which was fine using Tracker before). Switching from logitBCE to BCE stops my training from diverging. This might also be related to the recently reported weird training behaviour using Zygote in other issues.