snowkylin / tensorflow-handbook

简单粗暴 TensorFlow 2 | A Concise Handbook of TensorFlow 2 | A concise introductory tutorial on TensorFlow 2
https://tf.wiki

Batches returned by the get_batch method of the MNISTLoader class can contain duplicate items #42

Open opensourcedigest opened 4 years ago

opensourcedigest commented 4 years ago

Change this code:

    def get_batch(self, batch_size):
        # Randomly draw batch_size elements from the dataset and return them
        index = np.random.randint(0, np.shape(self.train_data)[0], batch_size)
        return self.train_data[index, :], self.train_label[index]

to:

    def get_batch(self, batch_size):
        # Randomly draw batch_size distinct elements from the dataset and return them
        # index = np.random.randint(0, np.shape(self.train_data)[0], batch_size)  # old: samples with replacement
        index = np.random.choice(np.shape(self.train_data)[0], batch_size, replace=False)
        return self.train_data[index, :], self.train_label[index]

Sampling without replacement ensures that no batch contains duplicate items.
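A minimal, self-contained sketch of the difference, assuming NumPy and using synthetic data in place of MNIST. `ToyLoader`, `get_batch_old` and `duplicate_probability` are hypothetical names for illustration only; the 60,000-example size matches the MNIST training set, while the batch size of 50 is an assumed value. When indices are drawn with replacement, the chance of at least one repeat in a batch of size b from N examples follows the birthday problem, P = 1 - ∏_{k=0}^{b-1} (1 - k/N).

```python
import numpy as np


def duplicate_probability(num_examples, batch_size):
    """Chance that batch_size indices drawn uniformly WITH replacement from
    range(num_examples) contain at least one repeat (the birthday problem)."""
    # P(all distinct) = prod_{k=0}^{batch_size-1} (1 - k / num_examples)
    k = np.arange(batch_size)
    return 1.0 - np.prod(1.0 - k / num_examples)


class ToyLoader:
    """Hypothetical stand-in for MNISTLoader, filled with synthetic data so the
    example runs without downloading MNIST."""

    def __init__(self, num_examples=60000):
        # One placeholder feature per example instead of a 28x28 image; the
        # label is the example's own index, so a repeated label means a
        # repeated example.
        self.train_data = np.arange(num_examples, dtype=np.float32).reshape(-1, 1)
        self.train_label = np.arange(num_examples, dtype=np.int64)

    def get_batch_old(self, batch_size):
        # Original version: randint samples indices WITH replacement.
        index = np.random.randint(0, np.shape(self.train_data)[0], batch_size)
        return self.train_data[index, :], self.train_label[index]

    def get_batch(self, batch_size):
        # Proposed version: distinct indices (WITHOUT replacement).
        # Raises ValueError if batch_size exceeds the number of examples.
        index = np.random.choice(np.shape(self.train_data)[0], batch_size, replace=False)
        return self.train_data[index, :], self.train_label[index]


loader = ToyLoader()
trials = 1000
# Count batches in which at least one example appears more than once.
old_dups = sum(len(np.unique(loader.get_batch_old(50)[1])) < 50 for _ in range(trials))
new_dups = sum(len(np.unique(loader.get_batch(50)[1])) < 50 for _ in range(trials))

print(f"predicted duplicate rate (old): {duplicate_probability(60000, 50):.4f}")
print(f"observed duplicate rate (old):  {old_dups / trials:.4f}")
print(f"observed duplicate rate (new):  {new_dups / trials:.4f}")  # always 0.0000
```

With N = 60,000 and a batch size of 50 the predicted rate comes out to roughly 2% per batch, so over thousands of training steps the original code will almost certainly produce some batches with repeated examples. Note that `replace=False` only guarantees uniqueness within a single batch: successive `get_batch` calls are independent, so the same example can still appear in consecutive batches. If every example should be seen exactly once per epoch, a common alternative is to shuffle the indices once per epoch (for example with `np.random.permutation`) and slice consecutive batches from that order.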

huan commented 4 years ago

Thank you for the suggestion! Could you please submit a Pull Request to fix this?

Much appreciated!