whitead / dmol-book

Deep learning for molecules and materials book
https://dmol.pub

Data shuffling issue on page /applied/QM9.html #169

Closed awesomepgm closed 2 years ago

awesomepgm commented 2 years ago

The shuffling done by TensorFlow seems to re-shuffle the dataset every time its elements are accessed (reshuffle_each_iteration defaults to True), so the take/skip calls below each see a different ordering. As a result, many examples end up in both the train and test sets, which could distort the evaluation of the model. I apologize if this was intentional, but I can't see any way in which it would be beneficial.
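
For illustration, here is a minimal sketch of the behaviour on a toy dataset (not the book's QM9 records; assuming TensorFlow 2.x eager execution):

import tensorflow as tf

# toy dataset of 10 integers, shuffled with the default
# reshuffle_each_iteration=True, just like the book's code
data = tf.data.Dataset.range(10).shuffle(10)
test = data.take(5)
train = data.skip(5)

# each iteration below triggers a fresh shuffle, so the two "splits"
# come from different orderings and can share elements
test_ids = {int(x) for x in test}
train_ids = {int(x) for x in train}
print(test_ids & train_ids)  # frequently non-empty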

The offending code:

shuffled_data = data.shuffle(7000)
test_set = shuffled_data.take(1000)
valid_set = shuffled_data.skip(1000).take(1000)
train_set = shuffled_data.skip(2000).take(5000)

How I tested it:

I ran all of the notebook's code up to the offending lines, then looped over every pair of elements from the train and test sets and checked whether they were identical. There were many duplicates.

Testing code:

train_list = [convert_record(d) for d in train_set]
test_list = [convert_record(d) for d in test_set]
# keep only the inputs; each is a (one-hot element matrix, coordinate array) pair
xs = [d[0] for d in train_list]
test_xs = [d[0] for d in test_list]
for x in xs:
  for x2 in test_xs:
    if np.all(x[0]==x2[0]) and np.all(x[1]==x2[1]):
      #print offending pair
      print(x,x2)
      break
  else:
    #continue if no break
    continue
  #if there is a break, break out of both loops
  break

One output:

(array([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]), 
       array([[-0.14904857,  1.3812039 , -0.08702838],
       [ 0.1812488 ,  0.01059157,  0.08730909],
       [ 0.7283333 , -0.586212  , -1.0655012 ],
       [ 1.2533858 , -1.9072461 , -0.72668207],
       [ 1.7073385 , -2.9862874 , -0.46708187],
       [-0.27883947, -0.6897319 , -2.1567938 ],
       [-1.0674677 , -0.7262004 , -2.999723  ],
       [ 0.73436487,  1.970557  , -0.3732077 ],
       [-0.93057024,  1.5173428 , -0.8457413 ],
       [-0.51839566,  1.7318261 ,  0.87758136],
       [ 1.5523492 ,  0.0340934 , -1.4610858 ],
       [ 2.101044  , -3.9448798 , -0.2318701 ]], dtype=float32)) 
(array([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]), 
       array([[-0.14904857,  1.3812039 , -0.08702838],
       [ 0.1812488 ,  0.01059157,  0.08730909],
       [ 0.7283333 , -0.586212  , -1.0655012 ],
       [ 1.2533858 , -1.9072461 , -0.72668207],
       [ 1.7073385 , -2.9862874 , -0.46708187],
       [-0.27883947, -0.6897319 , -2.1567938 ],
       [-1.0674677 , -0.7262004 , -2.999723  ],
       [ 0.73436487,  1.970557  , -0.3732077 ],
       [-0.93057024,  1.5173428 , -0.8457413 ],
       [-0.51839566,  1.7318261 ,  0.87758136],
       [ 1.5523492 ,  0.0340934 , -1.4610858 ],
       [ 2.101044  , -3.9448798 , -0.2318701 ]], dtype=float32))

As you can see, the two records are identical even though one comes from the train set and the other from the test set. There are many more duplicate pairs than this one (on the order of hundreds). Please let me know if I am missing something.

whitead commented 2 years ago

Thanks for spotting this mistake @awesomepgm! I changed the code to the version below and added you to the contributors section.

# we'll just train on 5,000 and use 1,000 for test
# shuffle, but only once (reshuffle_each_iteration=False) so
# we lock in which are train/test/val
shuffled_data = data.shuffle(7000, reshuffle_each_iteration=False)
test_set = shuffled_data.take(1000)
valid_set = shuffled_data.skip(1000).take(1000)
train_set = shuffled_data.skip(2000).take(5000)
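
If anyone wants to confirm the fix, the check from the report above can be re-run against the new splits. A quick sketch (it reuses convert_record and np from the notebook; the nested loop is slow but fine for a one-off check):

# count train/test pairs that are identical; should be 0 after the fix
test_xs = [convert_record(d)[0] for d in test_set]
train_xs = [convert_record(d)[0] for d in train_set]
overlap = sum(
    1
    for x in train_xs
    for x2 in test_xs
    if np.all(x[0] == x2[0]) and np.all(x[1] == x2[1])
)
print(overlap)
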
awesomepgm commented 2 years ago

Sorry, but if you don't mind, could you change my name? I forgot that I did not have my real name displayed on my GitHub page. It should be there now; it is Parsa Pourghasem. I also wanted to mention that, although I have not tested it, chapter 8, "Graph Neural Networks", might have the same problem, since it uses the same shuffling pattern in section 8.12.

Here is the offending code:

data = tf.data.Dataset.from_generator(
    generator,
    output_signature=(
        tf.TensorSpec(shape=(None, 2), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.int32),
    ),
).shuffle(1000)

# The shuffling above is really important because this dataset is in order of labels!

val_data = data.take(100)
test_data = data.skip(100).take(100)
train_data = data.skip(200)
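
If it is the same issue, I assume the analogous one-line fix would work here too (untested sketch; same generator and output_signature as in the chapter):

# shuffle once (reshuffle_each_iteration=False) so the take/skip splits
# below stay disjoint; the dataset is in label order, so shuffling is
# still needed
data = tf.data.Dataset.from_generator(
    generator,
    output_signature=(
        tf.TensorSpec(shape=(None, 2), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.int32),
    ),
).shuffle(1000, reshuffle_each_iteration=False)

val_data = data.take(100)
test_data = data.skip(100).take(100)
train_data = data.skip(200)
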
whitead commented 2 years ago

Thanks @awesomepgm! I looked through the rest of the book too and I don't think there are any more examples like that. Got this one fixed as well and updated your name.