rses-dl-course / rses-dl-course.github.io


Broken Plots in Lab 03 #15

Closed: EdwinB12 closed this issue 1 year ago

EdwinB12 commented 1 year ago

The plotting functions appear to expect predicted values between 0 and 1.

For example, the image-plotting code multiplies the highest prediction by 100 to get a percentage for the axis label:

```python
plt.xlabel("{} {:2.0f}% ({})".format(class_names[predicted_label], 100*np.max(predictions_array), class_names[true_label]), color=color)
```

and `plot_value_array` fixes the y-axis with `plt.ylim([0, 1])`:


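```python
import matplotlib.pyplot as plt
import numpy as np

def plot_value_array(i, predictions_array, true_label):
  # Select the prediction vector and the true label for image i
  predictions_array, true_label = predictions_array[i], true_label[i]
  plt.grid(False)
  plt.xticks([])
  plt.yticks([])
  thisplot = plt.bar(range(10), predictions_array, color="#777777")
  # Fixed 0-1 y-axis: assumes predictions_array holds probabilities
  plt.ylim([0, 1])
  predicted_label = np.argmax(predictions_array)

  # Highlight the predicted class in red and the true class in blue
  thisplot[predicted_label].set_color('red')
  thisplot[true_label].set_color('blue')
```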

This gives plots like this:

[Image: plots produced with the 0-1 y-limits in place]

If the plotting limits are removed, the same plot looks like this:

[Image: the same plots with the y-limits removed]

I suspect the code was written when softmax was used as the activation of the model's final layer. Now that it has been removed and replaced by `from_logits=True` in the loss, the model outputs raw logits, so the predictions need some processing to map them into the 0-1 range.
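A minimal sketch of that processing, assuming `predictions` is the raw `(num_images, 10)` logit array returned by `model.predict` for a model whose last layer has no softmax. The helper name `logits_to_probs` and the logit values are hypothetical; the helper applies a numerically stable softmax along the class axis, so each row lies in 0-1 and sums to 1, which is what `plt.ylim([0, 1])` and the `100*np.max(...)` label assume.

```python
import numpy as np

def logits_to_probs(logits):
    """Hypothetical helper: row-wise softmax over the last (class) axis."""
    logits = np.asarray(logits, dtype=np.float64)
    # Subtract the row maximum before exponentiating to avoid overflow;
    # softmax is shift-invariant, so the result is unchanged.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

# Illustrative logits for two images (made-up values, not from the lab).
logits = np.array([[2.0, -1.0, 8.5, 0.3, -4.2, 1.1, 0.0, -0.7, 3.3, -2.0],
                   [-3.0, 6.0, 1.0, 0.5, 0.0, -1.5, 2.2, -0.4, 0.9, 1.8]])
probs = logits_to_probs(logits)

# Softmax preserves the ranking, so np.argmax picks the same predicted label.
print(np.argmax(logits, axis=1), np.argmax(probs, axis=1))
print(probs.sum(axis=1))
```

`probs` could then be passed to `plot_value_array` (and the image-plotting function) in place of the raw predictions. With TensorFlow, `tf.nn.softmax(predictions).numpy()` should give the same result.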

I think there is an argument here for using softmax instead of `from_logits`.
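Either setup should train the same model: `from_logits=True` folds the softmax into the loss, while a softmax output layer applies it inside the model and the loss then receives probabilities. A minimal NumPy sketch of that equivalence for sparse categorical cross-entropy, assuming integer class labels. The function names are hypothetical, and the log-sum-exp form is the standard stable way to compute the loss from logits, not necessarily Keras's exact implementation.

```python
import numpy as np

def softmax(z):
    """Row-wise softmax with the usual max-subtraction for stability."""
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def ce_from_probs(probs, labels):
    """Hypothetical: cross-entropy when the model already outputs probabilities."""
    return -np.log(probs[np.arange(len(labels)), labels])

def ce_from_logits(logits, labels):
    """Hypothetical: cross-entropy computed directly from logits via log-sum-exp."""
    m = logits.max(axis=1, keepdims=True)
    log_sum_exp = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))
    return log_sum_exp - logits[np.arange(len(labels)), labels]

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 10))  # random stand-in for model outputs
labels = np.array([3, 0, 9, 5])

loss_a = ce_from_probs(softmax(logits), labels)
loss_b = ce_from_logits(logits, labels)
print(np.allclose(loss_a, loss_b))

# With an extreme logit gap the probability underflows to 0, so this naive
# probability-based loss becomes inf while the logit form stays finite.
extreme = np.array([[1000.0, 0.0]])
with np.errstate(divide="ignore"):
    print(ce_from_probs(softmax(extreme), np.array([1])),
          ce_from_logits(extreme, np.array([1])))
```

So the trade-off is mainly practical: computing the loss from logits is more robust numerically, while a softmax output layer means `model.predict` already returns the 0-1 values the plotting code expects. The TensorFlow basic image classification tutorial takes a third route: it keeps `from_logits=True` for training and wraps the trained model with a `tf.keras.layers.Softmax()` layer for prediction.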

EdwinB12 commented 1 year ago
[Image: plots after switching back to a softmax output]

The softmax approach indeed fixes the plotting :)