msyim / TensorFlowFun

Me playing with TF

Note on Logistic Regression #2

Open

msyim opened this issue 7 years ago

msyim commented 7 years ago

Method 1 in the code works fine:

# Method 1: feed the raw logits to TF's cross-entropy-with-logits op.
hypothesis = tf.matmul(X, W) + b
# Old positional signature (logits, targets); newer TF versions require the
# keyword form sigmoid_cross_entropy_with_logits(labels=Y, logits=hypothesis).
# Note that cost is a per-example vector here, so minimize() ends up following
# the gradient of its sum.
cost = tf.nn.sigmoid_cross_entropy_with_logits(hypothesis, Y)
optimizer = tf.train.GradientDescentOptimizer(0.001).minimize(cost)
# hypothesis is a raw logit, so the decision boundary is 0 (i.e. sigmoid > 0.5).
predicted = tf.cast(hypothesis > 0, dtype=tf.float32)
accuracy = tf.reduce_mean(tf.cast(tf.equal(predicted, Y), dtype=tf.float32))

However, the second method (which I believe is essentially the same as the first) doesn't seem to work:

# Method 2: apply the sigmoid first and write the cross-entropy out by hand.
hypothesis2 = tf.sigmoid(tf.matmul(X, W) + b)
cost2 = -tf.reduce_mean(Y * tf.log(hypothesis2) + (1 - Y) * tf.log(1 - hypothesis2))
optimizer2 = tf.train.GradientDescentOptimizer(0.001).minimize(cost2)
# hypothesis2 is already a probability, so 0.5 is the right threshold here.
predicted2 = tf.cast(hypothesis2 > 0.5, dtype=tf.float32)
accuracy2 = tf.reduce_mean(tf.cast(tf.equal(predicted2, Y), dtype=tf.float32))

With Method 2 I get cost: nan. Need to investigate this issue further....

msyim commented 7 years ago

Case caught:

# Method 1
iteration : 0 
cost : [[  1.01688766e+01]
            [  1.26653758e-05]
            [  1.21741109e-02]
            [  4.13549969e-07]
            [  7.84303741e+01]
            [  2.13611298e+01]
            [  2.63713766e-03]] 
hypothesis : [[-10.1688385 ]
            [ 11.27663231]
            [ -4.40235043]
            [-14.69848728]
            [ 78.43037415]
            [-21.36112976]
            [  5.93674231]] 
sig(hypothesis) : [[  3.83453662e-05]
            [  9.99987364e-01]
            [  1.21003054e-02]
            [  4.13549884e-07]
            [  1.00000000e+00]
            [  5.28419974e-10]
            [  9.97366369e-01]] 
pred : [[ 0.] [ 1.] [ 0.] [ 0.] [ 1.] [ 0.] [ 1.]] 
acc : 0.571429

For the 5th record, sig(hyp) is numerically exactly 1 (the sigmoid has saturated in float32), while the ground truth is 0. With the second method, this results in:

cost2 = -tf.reduce_mean(Y*log(hyp) + (1-Y)*log(1-hyp))
      = -tf.reduce_mean(... + (1-0)*log(1-1) + ...)
      = -tf.reduce_mean(... + log(0) + ...)
      = inf

so the cost for that batch blows up to inf rather than nan directly. The nan appears because the gradient of that inf term contains 0 * inf (the saturated sigmoid has zero derivative while the derivative of -log(1-hyp) is infinite), so the weights become nan after the update and every cost from then on is nan, which is why I was getting nan costs.
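
Here is a quick NumPy check of that failure mode (not code from the thread, just plugging in the saturated value for the 5th record; NumPy will emit divide-by-zero warnings):

import numpy as np

# 5th record: sig(hypothesis) saturated to exactly 1, label 0.
h = np.float32(1.0)
z = np.float32(0.0)

loss  = -(z * np.log(h) + (1 - z) * np.log(1 - h))  # -log(0) = inf
dL_dh = (1 - z) / (1 - h) - z / h                   # derivative of the loss w.r.t. h: 1/0 = inf
dh_dx = h * (1 - h)                                 # sigmoid derivative, 0 at saturation
dL_dx = dL_dh * dh_dx                               # inf * 0 = nan -> nan weight update

print(loss, dL_dh, dh_dx, dL_dx)  # inf inf 0.0 nan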

However, as seen above, using tf.nn.sigmoid_cross_entropy_with_logits(hypothesis, Y) avoids this problem entirely. I will look into how tf.nn.sigmoid_cross_entropy_with_logits is implemented.
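
(As an aside, not something used in this thread: a common workaround when the loss is computed from the sigmoid output directly is to clip it away from exactly 0 and 1 before taking the logs, reusing the Method 2 tensors above. The logits-based op is still the cleaner fix.)

# Workaround sketch (not used in this thread): clamp the sigmoid output so the
# logs stay finite even when the sigmoid saturates.
eps = 1e-7
hyp2_clipped = tf.clip_by_value(hypothesis2, eps, 1.0 - eps)
cost2_safe = -tf.reduce_mean(Y * tf.log(hyp2_clipped)
                             + (1 - Y) * tf.log(1 - hyp2_clipped))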

msyim commented 7 years ago

sigmoid_cross_entropy_with_logits, as defined in tensorflow/python/ops/nn.py:

def sigmoid_cross_entropy_with_logits(logits, targets, name=None):
  """Computes sigmoid cross entropy given `logits`.

  Measures the probability error in discrete classification tasks in which each
  class is independent and not mutually exclusive.  For instance, one could
  perform multilabel classification where a picture can contain both an elephant
  and a dog at the same time.

  For brevity, let `x = logits`, `z = targets`.  The logistic loss is

        z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
      = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
      = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
      = (1 - z) * x + log(1 + exp(-x))
      = x - x * z + log(1 + exp(-x))

The implementation avoids the issue by computing the last expression. Even when x is quite large this is not a problem, since exp(-x) simply underflows towards 0 and log(1 + exp(-x)) goes to 0, so the loss stays finite (whereas the first expression would still blow up to inf through log(1 - sigmoid(x)) = log(0)).
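
(The rest of that docstring, not quoted above, also notes that to avoid exp(-x) overflowing for negative x the implementation actually uses the equivalent form max(x, 0) - x * z + log(1 + exp(-abs(x))).) A small NumPy sketch (not from the thread) comparing the naive formula with the reformulated one on the 5th record's logit:

import numpy as np

x = np.float32(78.43037415)  # logit of the 5th record from the printout above
z = np.float32(0.0)          # its ground-truth label

sig = 1.0 / (1.0 + np.exp(-x))                          # rounds to exactly 1.0
naive = -(z * np.log(sig) + (1 - z) * np.log(1 - sig))  # log(1 - 1) = log(0) -> inf
stable = x - x * z + np.log1p(np.exp(-x))               # x - x*z + log(1 + exp(-x))

print(sig)     # 1.0
print(naive)   # inf
print(stable)  # ~78.43, i.e. the per-example cost Method 1 printed for that record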