zalandoresearch / fashion-mnist

A MNIST-like fashion product database. Benchmark :point_down:
http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/
MIT License

Loading the dataset from the local path using tensorflow 2.0 #167

Open henryhuanghenry opened 3 years ago

henryhuanghenry commented 3 years ago

I downloaded the dataset and placed it in an arbitrary path. I found that some people have trouble loading the dataset from a local path using TensorFlow 2.0. The API tf.keras.datasets.fashion_mnist.load_data() does not seem to support loading data from a local path.

I wrote a new function that may help with this issue, and I hope it helps somebody in need. I don't know whether the issue is big enough for a pull request, so I am opening an issue and posting my code here. I hope this won't cause any inconvenience.

The code of the new function:

import os
import numpy as np
import gzip
def load_data_fromlocalpath(input_path):
  """Loads the Fashion-MNIST dataset.
  Modified by Henry Huang on 2020/12/24.
  We assume that input_path is a valid directory path.
  We also assume that all four dataset files are in that directory.

  Loads local data from the directory `input_path`.

  Returns:
      Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
  """
  files = [
      'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
      't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'
  ]

  paths = []
  for fname in files:
    paths.append(os.path.join(input_path, fname))  # The location of the dataset.

  with gzip.open(paths[0], 'rb') as lbpath:
    y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)

  with gzip.open(paths[1], 'rb') as imgpath:
    x_train = np.frombuffer(
        imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)

  with gzip.open(paths[2], 'rb') as lbpath:
    y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8)

  with gzip.open(paths[3], 'rb') as imgpath:
    x_test = np.frombuffer(
        imgpath.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)

  return (x_train, y_train), (x_test, y_test)

When calling this function:

(x_train, y_train), (x_test, y_test) = load_data_fromlocalpath('Your path')
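
The `offset=8` and `offset=16` arguments in the function skip the headers of the IDX files: a label file starts with a 4-byte magic number and a 4-byte item count, and an image file adds a 4-byte row count and a 4-byte column count, all big-endian. Below is a minimal sketch, assuming the files follow this standard IDX layout, of how the header could be checked instead of skipped. The helper names `read_idx_labels` and `read_idx_images` are hypothetical and not part of the function above.

```python
import gzip
import struct

import numpy as np

# Magic numbers of the IDX format used by MNIST-style files.
LABEL_MAGIC = 2049  # 0x00000801: unsigned bytes, 1 dimension
IMAGE_MAGIC = 2051  # 0x00000803: unsigned bytes, 3 dimensions


def read_idx_labels(path):
    """Read a gzipped IDX1 label file and validate its 8-byte header."""
    with gzip.open(path, 'rb') as f:
        # Header: magic number and item count, big-endian unsigned ints.
        magic, count = struct.unpack('>II', f.read(8))
        if magic != LABEL_MAGIC:
            raise ValueError(f'{path}: unexpected magic number {magic}')
        labels = np.frombuffer(f.read(), dtype=np.uint8)
    if labels.size != count:
        raise ValueError(
            f'{path}: header says {count} labels, found {labels.size}')
    return labels


def read_idx_images(path):
    """Read a gzipped IDX3 image file and validate its 16-byte header."""
    with gzip.open(path, 'rb') as f:
        # Header: magic number, image count, rows, columns.
        magic, count, rows, cols = struct.unpack('>IIII', f.read(16))
        if magic != IMAGE_MAGIC:
            raise ValueError(f'{path}: unexpected magic number {magic}')
        pixels = np.frombuffer(f.read(), dtype=np.uint8)
    if pixels.size != count * rows * cols:
        raise ValueError(
            f'{path}: header implies {count * rows * cols} pixels, '
            f'found {pixels.size}')
    return pixels.reshape(count, rows, cols)
```

Reading the shape from the header means a truncated or mismatched download raises an error early, rather than failing later in `reshape` or silently producing misaligned labels.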

Yanglian666 commented 3 years ago

good

guorouda commented 3 years ago

good job.

ChunJen commented 1 year ago

thanks for sharing

ali713111 commented 1 year ago

  1. Load local data from path ‘input_path’ use back ticks '''
  2. return '(x_train, y_train), (x_test, y_test)'

2484114905 commented 10 months ago

thank you

Shuangqing-Xu commented 3 months ago

Thank you for sharing, good job! :sob: