keras-team / keras-preprocessing

Utilities for working with image data, text data, and sequence data.

Image data flows modify Numpy's global RNG #342

Open a-cass opened 3 years ago

a-cass commented 3 years ago

Describe the problem

When the ImageDataGenerator.flow* methods are used to yield image batches, the seed parameter modifies NumPy's global random number generator. Similar behaviour has been identified in other parts of the Keras library, e.g. this issue.

Once a batch has been yielded (and the global seed has thereby been set), any subsequent calls to numpy.random.* return the same numbers on every run. In my case I wanted to select and view random images from a batch and found that the same images were always being selected. I include an example below in which I use the flow_from_dataframe method to load 8-bit RGB images from my local disk.
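
The effect can be reproduced without any image data. The following is a minimal sketch that assumes the flow reseeds the global RNG each time a batch is yielded; leaky_batches and draw_after_first_batch are hypothetical stand-ins, not library code.

```python
import numpy as np


def leaky_batches(seed, n_batches=3):
    """Hypothetical stand-in for a data flow that reseeds the global RNG."""
    for i in range(n_batches):
        np.random.seed(seed + i)  # side effect: resets NumPy's global RNG
        yield i


def draw_after_first_batch(seed):
    """Fetch one 'batch', then draw from the global RNG."""
    gen = leaky_batches(seed)
    next(gen)
    return np.random.randint(16, size=8)


first = draw_after_first_batch(42)
second = draw_after_first_batch(42)
print(np.array_equal(first, second))  # True: the draws repeat exactly
```

Because the global RNG is reset to a known seed just before the draw, the "random" indices are identical every time.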

Example

After seeing this behavior I submitted a question to Data Science Stack Exchange, which includes a worked example. The code is below; see SE for more information.

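# Imports (assumed; the original snippet omitted them)
import numpy as np
from keras_preprocessing.image import ImageDataGenerator

# Step 1
# Set up image data flow
# img_df (DataFrame of filenames) and img_dir are defined elsewhere
img_generator = ImageDataGenerator(rescale=1/255.)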
train_gen = img_generator.flow_from_dataframe(
                img_df, # filenames are read from column "filename"
                img_dir, # local directory containing image files
                y_col=None,
                target_size=(512,512),
                class_mode=None,
                shuffle=False, # I'm using separate mask images so no shuffling here
                batch_size=16,
                seed=42 # behavior occurs when using seed
            )

# Step 2
# Generate and print 8 random indices
# No batch of images retrieved yet; no use of seed
print(np.random.randint(16, size=8))
# Output: [ 7 15 13  3  6  3  2 14]  (different on every run)

# Step 3
# Now get a batch of images; seed is used
batch = next(train_gen)

# Step 4
# Generate and print 8 random indices
print(np.random.randint(16, size=8))
# Output: [ 6  1  3  8 11 13  1  9]  (always the same result)
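
Until this is fixed, one possible workaround on the user side is to draw the indices from a private generator instead of numpy.random.*. The sketch below assumes a NumPy version that provides np.random.default_rng; pick_indices is a hypothetical helper, and the explicit np.random.seed call only stands in for the reseeding that happens when a batch is yielded.

```python
import numpy as np

# A private generator; unseeded, so it differs from run to run.
rng = np.random.default_rng()


def pick_indices(rng, batch_size=16, k=8):
    """Pick k random positions in a batch using the private generator."""
    return rng.integers(batch_size, size=k)


state_before = rng.bit_generator.state
np.random.seed(42)  # stand-in for the reseeding done when a batch is yielded
state_after = rng.bit_generator.state

# The private generator's state is untouched by the global reseed.
unaffected = state_before == state_after
indices = pick_indices(rng)
print(unaffected)  # True
print(indices)     # varies between runs
```

This leaves the data flow itself unchanged, so the global RNG is still reseeded; only the index selection is shielded from it.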

Proposed Solution

It appears that the culprit is the base Iterator class, specifically the _flow_index method. Similar to the approach taken in the Keras repo (PR 12259), I would suggest implementing a local RNG so that the seed no longer touches NumPy's global state.
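
As a rough illustration of the local-RNG idea, here is a simplified sketch. LocalRNGIndexIterator is a hypothetical class written for this issue, not the existing Iterator and not the code from the Keras PR; it only shows that a per-instance np.random.RandomState can drive shuffling without changing the global stream.

```python
import numpy as np


class LocalRNGIndexIterator:
    """Hypothetical, simplified index iterator (not the library class)."""

    def __init__(self, n, batch_size, shuffle=False, seed=None):
        self.n = n
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Private RNG: seeding it never touches np.random's global state.
        self.rng = np.random.RandomState(seed)

    def __iter__(self):
        while True:
            index_array = np.arange(self.n)
            if self.shuffle:
                index_array = self.rng.permutation(self.n)
            for start in range(0, self.n, self.batch_size):
                yield index_array[start:start + self.batch_size]


# Reference draw from the global RNG with no iterator involved.
np.random.seed(0)
expected = np.random.randint(16, size=8)

# Same global seed, but a seeded batch is fetched before drawing.
np.random.seed(0)
batches = iter(LocalRNGIndexIterator(n=32, batch_size=16, shuffle=True, seed=42))
first_batch = next(batches)
observed = np.random.randint(16, size=8)

print(np.array_equal(expected, observed))  # True: global stream is unchanged
```

With this design the seed still makes the batch order reproducible, while unrelated numpy.random.* calls elsewhere in the program are unaffected.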

System information

Environment checklist