ageron / handson-ml

⛔️ DEPRECATED – See https://github.com/ageron/handson-ml3 instead.
Apache License 2.0

Ch3 - SGDClassifier predicts 8's, not 5's! #370

Closed: RoryWatts closed this issue 6 months ago

RoryWatts commented 5 years ago

Hi there,

Firstly, I think the book is sensational. Thank you very much.

I haven't been able to obtain the MNIST dataset the way the book does. Instead, I downloaded it from http://yann.lecun.com/exdb/mnist/ and unpacked the gzip files into their respective variables (X_train ... y_test), using the following:

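import gzip
import numpy as np

# Read the raw bytes of the gzipped training-image file
with gzip.open('./mnist/train-images-idx3-ubyte.gz', 'rb') as f:
    file_content = f.read()

# Skip the 16-byte header and reshape to one row of 784 (28x28) pixels per image.
# dtype='int8' interprets each byte as a signed value (-128 to 127).
X_train = np.frombuffer(file_content, dtype='int8')[16:].reshape(60000, 784)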

I can confirm that the labels match the images by plotting them with plt.imshow(), as in the textbook. However, when I train a simple SGDClassifier, the cross_val_score is worse than expected: array([0.8362, 0.8265, 0.81255])
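For context on those numbers: Chapter 3 of the book points out that a classifier which always answers "not 5" (its Never5Classifier) already exceeds 90% accuracy, because only about 10% of MNIST images are 5s. Scores between 0.81 and 0.84 sit below that trivial baseline, which hints at a problem upstream of the model. A minimal NumPy sketch of the baseline, assuming the scores come from the binary 5-detector and the labels are integers; `never5_accuracy` and `y_demo` are illustrative names, not taken from the book:

```python
import numpy as np


def never5_accuracy(y):
    """Accuracy of a classifier that always predicts 'not 5'.

    Every non-5 label counts as a correct prediction, so the accuracy
    is simply the fraction of labels that are not 5.
    """
    y = np.asarray(y)
    return float(np.mean(y != 5))


# Hypothetical toy labels (not MNIST): one 5 out of ten
y_demo = np.array([5, 0, 4, 1, 9, 2, 1, 3, 1, 4])
print(never5_accuracy(y_demo))  # 0.9
```

On the real training labels, `never5_accuracy(y_train)` gives the number a working 5-detector should comfortably beat.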

Looking at some of its predictions, it seems to like eights! If I ask it to predict on the X_train images whose label is 8, it classifies many more of them as fives than it does the images actually labeled 5.
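The comparison described above can be made systematic by counting, for each true digit, how often the binary detector answers "5". A healthy detector should give a rate near 1 for digit 5 and near 0 for every other digit. This sketch assumes integer labels and boolean predictions (for example from sklearn's `cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)`, as in the book); `positive_rate_by_digit` is a hypothetical helper:

```python
import numpy as np


def positive_rate_by_digit(y_true, y_pred_5):
    """For each true digit 0-9, return the fraction of images flagged as '5'.

    y_true:   integer labels (0-9)
    y_pred_5: boolean predictions from a binary 5-detector
    """
    y_true = np.asarray(y_true)
    y_pred_5 = np.asarray(y_pred_5, dtype=bool)
    rates = {}
    for digit in range(10):
        mask = y_true == digit
        # Skip digits that do not appear, to avoid averaging an empty array
        if mask.any():
            rates[digit] = float(y_pred_5[mask].mean())
    return rates


# Hypothetical toy data: two 5s (one caught), two 8s (both flagged), one 3
y_true = np.array([5, 5, 8, 8, 3])
y_pred_5 = np.array([True, False, True, True, False])
print(positive_rate_by_digit(y_true, y_pred_5))  # {3: 0.0, 5: 0.5, 8: 1.0}
```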

It's a confusing result. My guess is that the cause lies in how I read in the data, or in the nature of the data itself.
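On the data-reading side, one thing worth ruling out: the LeCun files are named `*-ubyte` because the MNIST page describes their contents as unsigned bytes, with pixel values from 0 to 255. Reading them with `dtype='int8'`, as in the snippet above, maps 128 to 255 onto -128 to -1, so the brightest stroke pixels become negative numbers. Below is a minimal sketch of a reader that follows the header layout given on that page (big-endian 32-bit fields, magic number 2051 for images and 2049 for labels); the function and constant names are hypothetical:

```python
import gzip

import numpy as np

IMAGES_MAGIC = 2051  # 0x00000803 on the MNIST page
LABELS_MAGIC = 2049  # 0x00000801


def load_idx_images(path):
    """Return an (n_images, rows * cols) uint8 array from a gzipped IDX image file."""
    with gzip.open(path, "rb") as f:
        data = f.read()
    # 16-byte header: magic, count, rows, cols as big-endian uint32
    header = np.frombuffer(data, dtype=">u4", count=4)
    magic, n_images, n_rows, n_cols = (int(v) for v in header)
    if magic != IMAGES_MAGIC:
        raise ValueError(f"unexpected magic number {magic} in {path}")
    # Pixels are unsigned bytes (0-255), so read them as uint8, not int8
    n_pixels = n_rows * n_cols
    pixels = np.frombuffer(data, dtype=np.uint8, offset=16, count=n_images * n_pixels)
    return pixels.reshape(n_images, n_pixels)


def load_idx_labels(path):
    """Return a uint8 array of labels from a gzipped IDX label file."""
    with gzip.open(path, "rb") as f:
        data = f.read()
    # 8-byte header: magic, count as big-endian uint32
    header = np.frombuffer(data, dtype=">u4", count=2)
    magic, n_labels = (int(v) for v in header)
    if magic != LABELS_MAGIC:
        raise ValueError(f"unexpected magic number {magic} in {path}")
    return np.frombuffer(data, dtype=np.uint8, offset=8, count=n_labels)


# The same byte read both ways: signed int8 wraps 200 around to -56
print(np.frombuffer(bytes([200]), dtype=np.int8))   # [-56]
print(np.frombuffer(bytes([200]), dtype=np.uint8))  # [200]
```

With the LeCun files this would be used as `X_train = load_idx_images('./mnist/train-images-idx3-ubyte.gz')` and `y_train = load_idx_labels('./mnist/train-labels-idx1-ubyte.gz')`, assuming the label file sits in the same `./mnist/` folder.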

I've attached a reference image below, generated with imshow using the same parameters as in the book. The code above shows how I unpacked the image files.

[Attached image: reference plot generated with imshow]

ageron commented 5 years ago

Hi @RoryWatts, did you try loading MNIST with fetch_openml(), as shown in the notebook? Unfortunately, the fetch_mldata() function no longer works because mldata.org is gone. I'm not sure why you are getting disappointing results:
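Loading MNIST through `fetch_openml()` looks roughly like the sketch below, which follows the book's approach of using the first 60,000 images for training and the last 10,000 for testing. It assumes scikit-learn 0.22 or later (for the `as_frame` argument); `load_mnist_openml` is a hypothetical helper name. `fetch_openml` caches the download under `data_home` (by default `~/scikit_learn_data`), so once a download succeeds, later calls read from disk.

```python
import numpy as np


def load_mnist_openml(data_home=None):
    """Download MNIST from OpenML and return X_train, X_test, y_train, y_test.

    scikit-learn is imported inside the function so this module can be
    imported without it; the first call needs network access.
    """
    from sklearn.datasets import fetch_openml

    # as_frame=False returns NumPy arrays rather than a pandas DataFrame
    mnist = fetch_openml("mnist_784", version=1, as_frame=False, data_home=data_home)
    X = mnist.data
    # OpenML stores the targets as strings ("0" to "9"); convert to integers
    y = mnist.target.astype(np.uint8)
    # First 60,000 images for training, last 10,000 for testing
    return X[:60000], X[60000:], y[:60000], y[60000:]
```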

Hope this helps.

RoryWatts commented 5 years ago

Many thanks, @ageron. I tried downloading MNIST with fetch_openml() as you suggested, but it results in a timeout error.

Thanks for the suggestions. The images had already been shuffled, but I'll try scaling the pixels. In the meantime, I have had success using the MNIST dataset (.csv) from Kaggle.
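The two preprocessing steps mentioned here, shuffling and scaling, can be sketched with NumPy alone. The book shuffles with a random permutation and scales with sklearn's `StandardScaler` applied to `X_train.astype(np.float64)`; the helper below is a hypothetical stand-in that assumes `X` and `y` are NumPy arrays with one row per image:

```python
import numpy as np


def shuffle_and_standardize(X, y, seed=42):
    """Shuffle rows of X and y in unison, then standardize each pixel column."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X))
    X, y = X[idx], y[idx]
    # Convert from uint8 first so the subtraction below cannot wrap around
    X = X.astype(np.float64)
    mean = X.mean(axis=0)
    std = X.std(axis=0)
    # Constant columns (e.g. always-blank border pixels) get std 1 to avoid
    # dividing by zero, which is also how StandardScaler treats them
    std[std == 0] = 1.0
    return (X - mean) / std, y
```

In a real train/test workflow the mean and standard deviation should come from the training set only and be reused on the test set, which is what `StandardScaler.fit` followed by `transform` does.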