keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0
61.28k stars 19.38k forks source link

Memory leak when using custom DataGenerator #19907

Open Omitg24 opened 2 weeks ago

Omitg24 commented 2 weeks ago

For the past 3 weeks I've been searching nonstop for a solution to this problem: when training an LSTM model with a custom DataGenerator, Keras ends up using all of my RAM. The goal of the project is to predict sleep stages. This script is expected to parallelize 15 different participants, each with 10 folds (10 train and 10 validation partitions), and in a later phase to test each participant on its respective test partition. With that said, this is the LSTM network I'm currently using:

I'm using:

This network has been used in this project

def create_lstm1(number_inputs, window_size, n_classes=None):
    """Build and compile the two-layer stacked LSTM classifier ("LSTM1").

    Architecture: LSTM(50, return_sequences=True) -> Dropout(0.2)
    -> LSTM(100) -> Dropout(0.2) -> Dense(n_classes, softmax), compiled with
    categorical cross-entropy, the RMSprop optimizer and an accuracy metric.

    Args:
        number_inputs: Number of features per time step.
        window_size: Number of time steps in each input window.
        n_classes: Number of output classes. When omitted it falls back to
            ``ProjectConfig().n_phases``, exactly as before, so existing
            callers are unaffected. Passing it explicitly mirrors how
            ``create_lstm3`` receives the class count and lets the model be
            built without reading the project configuration.

    Returns:
        The compiled ``Sequential`` model.
    """
    if n_classes is None:
        n_classes = ProjectConfig().n_phases

    model = Sequential()

    # Input is a window of ``window_size`` steps with ``number_inputs`` features each.
    model.add(Input(shape=(window_size, number_inputs)))

    # First LSTM returns the full sequence so the second LSTM can consume it.
    model.add(LSTM(units=50, return_sequences=True))
    model.add(Dropout(0.2))

    # Second LSTM returns only its last output, which feeds the classifier head.
    model.add(LSTM(units=100, return_sequences=False))
    model.add(Dropout(0.2))

    model.add(Dense(units=n_classes, activation='softmax'))

    model.compile(loss="categorical_crossentropy", optimizer="rmsprop", metrics=['accuracy'])
    return model

Then, I implemented this custom DataGenerator, which suits my problem:

import math
from statistics import mode

import keras
import tracemalloc
import gc
import numpy as np
import pandas as pd
from imblearn.over_sampling import SMOTE
from sklearn.preprocessing import StandardScaler, LabelEncoder

from src.utils import utils

class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that yields batches of fixed-length sensor windows.

    All preprocessing is done once, eagerly, in the constructor:

    1. When ``is_training`` is true, every non-majority class is oversampled
       with SMOTE.
    2. The data is cut into windows of ``window_size * sample_frequency`` rows.
       A new window starts every ``window_overlap * sample_frequency`` rows.
       Incomplete trailing windows are dropped, and each window is labelled
       with the most common label it contains.
    3. The window labels are label-encoded and then one-hot encoded.

    When a batch is requested, every window in it is standardized on its own
    (a fresh ``StandardScaler`` per window). The ``lstm_mode`` feature
    transformation is then applied.
    """

    def __init__(self,
                 x_data: pd.DataFrame,
                 y_data: pd.DataFrame,
                 name: str = "DataGenerator",
                 window_size: int = 30,
                 window_overlap: int = 15,
                 batch_size: int = 32,
                 lstm_mode: int = 1,
                 n_clases: int = 5,
                 sample_frequency: int = 10,
                 shuffle: bool = True,
                 is_training: bool = True,
                 classes=None,
                 **kwargs):
        """Oversample, window and encode the data.

        Args:
            x_data: Feature rows. The training code passes the ``accx``,
                ``accy``, ``accz`` and ``hr`` columns.
            y_data: One label per row of ``x_data``.
            name: Prefix used in the memory-usage log lines.
            window_size: Window length, in multiples of ``sample_frequency`` rows.
            window_overlap: Stride between consecutive window starts, in
                multiples of ``sample_frequency`` rows.
            batch_size: Maximum number of windows per batch.
            lstm_mode: Feature transformation selector; see ``_update_mode``.
            n_clases: Width of the one-hot label vectors.
            sample_frequency: Rows per unit of ``window_size``/``window_overlap``.
            shuffle: Whether to shuffle the window order (at construction and
                at the end of every epoch).
            is_training: Whether to apply SMOTE oversampling.
            classes: Optional collection of every possible label. When given,
                the label encoder is fit on it, so every generator (e.g. train
                and validation) maps labels to the same one-hot columns. When
                omitted, the encoder is fit only on the labels present in this
                generator, as before. That gives a *different* mapping when a
                class is absent from the data.
            **kwargs: Forwarded to the ``keras.utils.Sequence`` base constructor.
        """
        super().__init__(**kwargs)
        self.x_data = x_data
        self.y_data = y_data
        self.name = name
        self.window_size = window_size
        self.window_overlap = window_overlap
        self.batch_size = batch_size
        self.lstm_mode = lstm_mode
        self.sample_frequency = sample_frequency
        self.shuffle = shuffle
        self.is_training = is_training

        # tracemalloc reports (0, 0) unless the caller has started tracing.
        current, peak = tracemalloc.get_traced_memory()
        print(f"\t\t\t\t{name} Memory before Oversampling: Current = {current / 10 ** 6} MB; Peak = {peak / 10 ** 6} MB")
        self.x_data, self.y_data = self._oversample(self.x_data, self.y_data)

        current, peak = tracemalloc.get_traced_memory()
        print(f"\t\t\t\t{name} Memory after Oversampling: Current = {current / 10 ** 6} MB; Peak = {peak / 10 ** 6} MB")
        self.x_data, self.y_data = self._create_windows(self.x_data, self.y_data)

        current, peak = tracemalloc.get_traced_memory()
        print(f"\t\t\t\t{name} Memory after Windows Creation: Current = {current / 10 ** 6} MB; Peak = {peak / 10 ** 6} MB")
        encoder = LabelEncoder()
        if classes is not None:
            # Shared, fixed label -> column mapping across generators.
            encoder.fit(classes)
            self.y_data = encoder.transform(self.y_data)
        else:
            self.y_data = encoder.fit_transform(self.y_data)
        self.y_data = keras.utils.to_categorical(self.y_data, num_classes=n_clases)
        self.on_epoch_end()  # initializes self.indexes

    def __len__(self):
        """Return the number of batches per epoch (the last one may be smaller)."""
        return math.ceil(len(self.indexes) / self.batch_size)

    def __getitem__(self, index):
        """Return batch ``index`` as ``(batch_x, batch_y)`` numpy arrays.

        ``batch_x`` stacks the transformed windows; ``batch_y`` stacks their
        one-hot labels. An index past the end yields empty arrays instead of
        raising.
        """
        start = max(0, min(len(self.indexes), index * self.batch_size))
        stop = min(len(self.indexes), (index + 1) * self.batch_size)
        batch_indexes = self.indexes[start:stop]

        # Scaling and the mode transform both build new DataFrames, so the
        # stored windows are never mutated and need no defensive copy. No
        # per-batch gc.collect(): a full collection on every batch is costly
        # and does not release memory still referenced elsewhere.
        batch_x = np.array([self._update_mode(self._scale_data(self.x_data[i]))
                            for i in batch_indexes])
        batch_y = self.y_data[batch_indexes]
        return batch_x, batch_y

    def _create_windows(self, x_data, y_data):
        """Cut ``x_data``/``y_data`` into equal-length, possibly overlapping windows.

        A window spans ``window_size * sample_frequency`` consecutive rows, and
        a new window starts every ``window_overlap * sample_frequency`` rows.
        Windows that would run past the end of the data are dropped. Each
        window is labelled with the most common label among its rows
        (``statistics.mode``; on a tie the first value encountered wins).

        Returns:
            ``(windows_x, windows_y)``: a list of DataFrame row slices and the
            matching list of labels.

        Raises:
            ValueError: If the stride ``window_overlap * sample_frequency`` is
                not positive.
        """
        step = self.window_overlap * self.sample_frequency
        length = self.window_size * self.sample_frequency
        if step <= 0:
            raise ValueError(
                f"window_overlap * sample_frequency must be positive, got {step}")

        # Access both x and y by position. ``y_data[list]`` on a Series is
        # label-based and can select the wrong rows (or raise KeyError) when
        # the index is not a 0..n-1 RangeIndex.
        labels = np.asarray(y_data)
        windows_x = []
        windows_y = []
        for start in range(0, len(x_data) - length + 1, step):
            end = start + length
            # A positional slice instead of a list lookup: it avoids copying
            # every overlapping window up front.
            windows_x.append(x_data.iloc[start:end])
            windows_y.append(mode(labels[start:end]))

        return windows_x, windows_y

    def _oversample(self, x_data, y_data):
        """SMOTE-oversample every non-majority class when training; else pass through."""
        if self.is_training:
            sm = SMOTE(random_state=0, sampling_strategy='not majority')
            X_resampled, y_resampled = sm.fit_resample(x_data, y_data)
            return X_resampled, y_resampled
        return x_data, y_data

    def _update_mode(self, x_data: pd.DataFrame):
        """Apply the ``lstm_mode`` feature transformation to one window.

        * mode 2: all existing columns plus an appended ``magacc`` column.
        * mode 3: only ``hr`` plus ``magacc``.
        * any other mode: the window unchanged (as a copy).

        ``magacc`` comes from ``utils.calculate_magnitude`` on ``accx``,
        ``accy`` and ``accz`` (presumably the Euclidean norm; not verified here).
        """
        if self.lstm_mode == 2:
            magacc = utils.calculate_magnitude(x_data['accx'], x_data['accy'], x_data['accz'])
            ret = x_data.copy()
            ret['magacc'] = magacc
        elif self.lstm_mode == 3:
            magacc = utils.calculate_magnitude(x_data['accx'], x_data['accy'], x_data['accz'])
            ret = x_data[['hr']].copy()
            ret['magacc'] = magacc
        else:
            ret = x_data.copy()
        return ret

    def on_epoch_end(self):
        """Rebuild (and optionally shuffle) the window order.

        Keras calls this at the end of every epoch and ignores its return
        value. The new order is therefore stored on ``self.indexes`` (it used
        to be only returned, so shuffling never took effect). It is still
        returned for backward compatibility.
        """
        indexes = np.arange(len(self.x_data))
        if self.shuffle:
            np.random.shuffle(indexes)
        self.indexes = indexes
        return indexes

    @staticmethod
    def _scale_data(x_data):
        """Standardize one window with a ``StandardScaler`` fit on that window only."""
        scaler = StandardScaler()
        scaled_data = scaler.fit_transform(x_data)
        return pd.DataFrame(scaled_data, columns=x_data.columns)

And finally, the training phase is the following:

import gc
import multiprocessing
import os
import shutil
import tracemalloc

import luigi
import numpy as np
import pandas as pd

from . import DeepPartitioning
from .data_generator import DataGenerator
from .lstm_creation import *
from ..utils import ProjectConfig, utils

class DeepTraining(luigi.Task):
    """Luigi task that trains one LSTM model per participant over k folds.

    For every participant produced by ``DeepPartitioning``, a single model is
    trained incrementally. For each of ``n_epochs`` passes it calls ``fit``
    once per fold, with one epoch per call. After each pass it writes the
    concatenated per-fold history and the weights. At the end it saves the
    whole model and copies the participant's test partition next to it.
    Participants are processed in parallel with a ``multiprocessing.Pool``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.results_path = ProjectConfig().deep_training_path
        # Number of model input features for each ``lstm_mode``.
        self.inputs_by_mode = {
            1: 4,  # accx, accy, accz, hr
            2: 5,  # accx, accy, accz, magnitude, hr
            3: 2  # magnitude, hr
        }

    def requires(self):
        return DeepPartitioning()

    def output(self):
        return luigi.LocalTarget(os.path.join(self.results_path, "output_paths.txt"))

    def run(self):
        """Train every participant and record all produced file paths."""
        os.makedirs(self.results_path, exist_ok=True)
        prev_files = utils.get_prev_files_path(self.input().path)

        patient_partitions = self.get_partitions(prev_files)

        path_list = self.run_experiments(patient_partitions)
        utils.create_output_paths_file(self.results_path, path_list)

    def run_experiments(self, partitions):
        """Process all participants in a process pool; return every written path.

        ``maxtasksperchild=1`` runs each participant in a fresh worker process.
        Whatever memory that participant's training accumulated is therefore
        released when the worker exits.
        """
        path_list = []

        with multiprocessing.Pool(processes=multiprocessing.cpu_count(), maxtasksperchild=1) as pool:
            results = pool.map(self.process_participant_wrapper, partitions.items())

        for result in results:
            path_list.extend(result)

        return path_list

    def process_participant_wrapper(self, participant_partitions):
        """Unpack a ``(participant, partitions)`` pair for ``Pool.map``."""
        participant, partitions = participant_partitions
        return self.process_participant(participant, partitions)

    def process_participant(self, participant, participant_partitions):
        """Train, checkpoint and save the model for one participant.

        Returns:
            Every file path written for this participant: per-epoch history
            CSVs and weights, the final model and the copied test partition.
        """
        config = ProjectConfig()
        path_list = []
        participant_path = os.path.join(self.results_path, participant)
        os.makedirs(participant_path, exist_ok=True)
        patient_idx = participant.split("_")[1]
        train_folds, validation_folds, test = self._get_folds(participant_partitions, patient_idx)
        model = self.create_neural_network()
        print(f"\tStarting participant {patient_idx}")

        for ep in range(config.n_epochs):
            epoch_path = os.path.join(participant_path, f"epoch_{ep}")
            os.makedirs(epoch_path, exist_ok=True)
            print(f"\t\tStarting epoch {ep}")
            # Keep only the plain metric dicts. A ``History`` callback also
            # references the model and training params it was attached to.
            history = []
            for k in range(config.n_splits):
                train_fold = pd.read_csv(train_folds[k], sep=";")
                validation_fold = pd.read_csv(validation_folds[k], sep=";")
                hist = self.train_fold(model, train_fold, validation_fold, participant, ep, k)
                history.append(hist.history)
                del hist

                current, peak = tracemalloc.get_traced_memory()
                print(
                    f"\t\t\tMemory usage after fold {k}: Current = {current / 10 ** 6} MB; Peak = {peak / 10 ** 6} MB")

                del train_fold
                del validation_fold
                gc.collect()

            # One row per fold: concatenate each metric's per-fit lists.
            final_history = {
                key: np.concatenate([fold_history[key] for fold_history in history])
                for key in history[0]
            }

            current, peak = tracemalloc.get_traced_memory()
            print(f"\t\tMemory usage after epoch {ep}: Current = {current / 10 ** 6} MB; Peak = {peak / 10 ** 6} MB")
            history_path = os.path.join(epoch_path, f"history_{participant}_{ep}_all_folds.csv")
            pd.DataFrame.from_dict(final_history).to_csv(history_path, sep=";")
            path_list.append(history_path)
            print(f"\t\tFinished epoch {ep}")
            print(f"\t\tSaving lstm ({config.neural_network}) weights")
            weights_path = os.path.join(epoch_path, f"weights_{patient_idx}_{ep}.weights.h5")
            model.save_weights(weights_path)
            path_list.append(weights_path)
        print(f"\tFinished participant {patient_idx}")
        print(f"\tSaving lstm ({config.neural_network}) model")
        model_path = os.path.join(participant_path, f"model_{patient_idx}.keras")
        model.save(model_path)
        path_list.append(model_path)

        test_path = os.path.join(participant_path, f"test_{patient_idx}.csv")
        shutil.copyfile(test, test_path)
        path_list.append(test_path)

        current, peak = tracemalloc.get_traced_memory()
        print(f"\tMemory usage after all epochs: Current = {current / 10 ** 6} MB; Peak = {peak / 10 ** 6} MB")

        del model
        gc.collect()
        return path_list

    def create_neural_network(self):
        """Build the compiled network selected by ``ProjectConfig().neural_network``.

        Raises:
            ValueError: If ``neural_network`` is not 1, 2 or 3. Previously the
                method returned ``None``, which only failed later in ``fit``.
        """
        config = ProjectConfig()
        mode = config.lstm_mode
        neural_network = config.neural_network
        number_inputs = self.inputs_by_mode[mode]
        window_size = config.w_size * config.sample_frequency

        # All neural networks are already compiled
        if neural_network == 1:
            print("Creating LSTM1 model")
            return create_lstm1(number_inputs, window_size)
        if neural_network == 2:
            print("Creating LSTM2 model")
            return create_lstm2(number_inputs, window_size)
        if neural_network == 3:
            print("Creating LSTM3 model")
            return create_lstm3(number_inputs, window_size, config.n_phases)
        raise ValueError(f"Unknown neural_network {neural_network!r}; expected 1, 2 or 3")

    def train_fold(self, model, train_fold, validation_fold, participant, ep, k):
        """Split one fold's DataFrames into features/labels and fit one epoch."""
        print(f"\t\t\tTraining model for participant {participant} - epoch {ep} - fold {k}")
        X_train, y_train = (train_fold[['accx', 'accy', 'accz', 'hr']], train_fold['stage'])
        X_validation, y_validation = (validation_fold[['accx', 'accy', 'accz', 'hr']], validation_fold['stage'])
        hist = self.train_model(model, X_train, y_train, X_validation, y_validation)
        return hist

    @staticmethod
    def train_model(model, x_train, y_train, x_validation, y_validation):
        """Fit ``model`` for one epoch on windowed train/validation generators.

        Returns the ``History`` object produced by ``model.fit``.

        NOTE(review): memory was reported to grow with every ``fit`` call made
        from here. The issue thread points to a TensorFlow data-pipeline leak
        (tensorflow/tensorflow#65675) rather than this code. Unverified.
        """
        config = ProjectConfig()
        train_generator = DataGenerator(x_data=x_train,
                                        y_data=y_train,
                                        name="train",
                                        window_size=config.w_size,
                                        window_overlap=config.w_overlapping,
                                        lstm_mode=config.lstm_mode,
                                        sample_frequency=config.sample_frequency,
                                        n_clases=config.n_phases)
        validation_generator = DataGenerator(x_data=x_validation,
                                             y_data=y_validation,
                                             name="validation",
                                             window_size=config.w_size,
                                             window_overlap=config.w_overlapping,
                                             lstm_mode=config.lstm_mode,
                                             n_clases=config.n_phases,
                                             sample_frequency=config.sample_frequency,
                                             is_training=False)

        # The DataFrames are not deleted here: the caller still references
        # them, so ``del`` on these parameter names would free nothing.
        return model.fit(train_generator,
                         steps_per_epoch=len(train_generator),
                         epochs=1,
                         validation_data=validation_generator,
                         validation_steps=len(validation_generator))

    @staticmethod
    def get_partitions(prev_files):
        """Group partition files as ``{participant_dir: {file_stem: path}}``.

        If two files share a stem for the same participant, the first one wins.
        """
        partitions = {}
        for prev_file in prev_files:
            patient = os.path.basename(os.path.dirname(prev_file))
            file_name = os.path.basename(prev_file).split(".")[0]
            partitions.setdefault(patient, {}).setdefault(file_name, prev_file)
        return partitions

    @staticmethod
    def _get_folds(partitions, patient_idx):
        """Return ``(train_folds, validation_folds, test)`` file paths for a patient.

        Raises ``KeyError`` if any expected fold or test file is missing.
        """
        train_folds = []
        validation_folds = []
        for i in range(ProjectConfig().n_splits):
            train_folds.append(partitions[f"train_fold_{patient_idx}_{i}"])
            validation_folds.append(partitions[f"validation_fold_{patient_idx}_{i}"])
        test = partitions[f"test_participant_{patient_idx}"]
        return train_folds, validation_folds, test

With that, I get the following output (showing only the first and the last epoch). You can see it ends up using about 80 GB of RAM for just one participant with 10 epochs and 10 folds.

# FIRST EPOCH
Starting participant 0
        Starting epoch 0
            Training model for participant patient_0 - epoch 0 - fold 0
Memory usage after fold 0: Current = 1033.934569 MB; Peak = 1556.608261 MB
            Training model for participant patient_0 - epoch 0 - fold 1
Memory usage after fold 1: Current = 1854.543623 MB; Peak = 2381.11604 MB
            Training model for participant patient_0 - epoch 0 - fold 2
Memory usage after fold 2: Current = 2675.151555 MB; Peak = 3201.725061 MB
            Training model for participant patient_0 - epoch 0 - fold 3
Memory usage after fold 3: Current = 3495.760576 MB; Peak = 4022.326754 MB
            Training model for participant patient_0 - epoch 0 - fold 4
Memory usage after fold 4: Current = 4316.366685 MB; Peak = 4842.94543 MB
            Training model for participant patient_0 - epoch 0 - fold 5
Memory usage after fold 5: Current = 5136.99149 MB; Peak = 5663.567691 MB
            Training model for participant patient_0 - epoch 0 - fold 6
Memory usage after fold 6: Current = 5957.613007 MB; Peak = 6484.199312 MB
            Training model for participant patient_0 - epoch 0 - fold 7
Memory usage after fold 7: Current = 6778.235883 MB; Peak = 7304.812643 MB
            Training model for participant patient_0 - epoch 0 - fold 8
Memory usage after fold 8: Current = 7598.856964 MB; Peak = 8125.438533 MB
            Training model for participant patient_0 - epoch 0 - fold 9
Memory usage after fold 9: Current = 8419.47895 MB; Peak = 8946.065265 MB
Memory usage after fold 9: Current = 8210.1733 MB; Peak = 8946.065265 MB
        Finished epoch 0

# LATEST EPOCH
        Starting epoch 9
            Training model for participant patient_0 - epoch 9 - fold 0
Memory usage after fold 0: Current = 74889.296057 MB; Peak = 75415.873067 MB
            Training model for participant patient_0 - epoch 9 - fold 1
Memory usage after fold 1: Current = 75709.918092 MB; Peak = 76236.499263 MB
            Training model for participant patient_0 - epoch 9 - fold 2
Memory usage after fold 2: Current = 76530.522405 MB; Peak = 77057.104749 MB
            Training model for participant patient_0 - epoch 9 - fold 3
Memory usage after fold 3: Current = 77351.125716 MB; Peak = 77877.709326 MB
            Training model for participant patient_0 - epoch 9 - fold 4
Memory usage after fold 4: Current = 78171.730176 MB; Peak = 78698.309015 MB
            Training model for participant patient_0 - epoch 9 - fold 5
Memory usage after fold 5: Current = 78992.352293 MB; Peak = 79518.938572 MB
            Training model for participant patient_0 - epoch 9 - fold 6
Memory usage after fold 6: Current = 79812.972282 MB; Peak = 80339.551835 MB
            Training model for participant patient_0 - epoch 9 - fold 7
Memory usage after fold 7: Current = 80633.590744 MB; Peak = 81160.172095 MB
            Training model for participant patient_0 - epoch 9 - fold 8
Memory usage after fold 8: Current = 81454.212892 MB; Peak = 81980.790902 MB
            Training model for participant patient_0 - epoch 9 - fold 9
Memory usage after fold 9: Current = 82274.83347 MB; Peak = 82801.413244 MB
Memory usage after fold 9: Current = 82065.526847 MB; Peak = 82801.413244 MB
        Finished epoch 9

I've tried explicitly deleting variables, calling the garbage collector, and calling clear_session() after finishing training each model. Since the training is incremental, I don't think I'm supposed to call clear_session() between folds.

Finally, in case it helps demonstrate the issue, I also ran memory_profiler to check whether memory was actually being freed (just not enough of it). This is the result for one epoch with 10 folds on one participant:


Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    62    484.1 MiB    484.1 MiB           1       @profile
    63                                             def process_participant(self, participant, participant_partitions):
    64    484.1 MiB      0.0 MiB           1           path_list = []
    65    484.1 MiB      0.0 MiB           1           os.makedirs(os.path.join(self.results_path, participant), exist_ok=True)
    66    484.1 MiB      0.0 MiB           1           patient_idx = participant.split("_")[1]
    67    484.2 MiB      0.1 MiB           1           train_folds, validation_folds, test = self._get_folds(participant_partitions, patient_idx)
    68    550.3 MiB     66.1 MiB           1           model = self.create_neural_network()
    69    550.3 MiB      0.0 MiB           1           print(f"\tStarting participant {patient_idx}")        
    70    550.3 MiB      0.0 MiB           1           tracemalloc.start()
    71                                         
    72   5377.8 MiB      0.0 MiB           2           for ep in range(ProjectConfig().n_epochs):
    73    550.3 MiB      0.0 MiB           1               epoch_path = os.path.join(self.results_path, participant, f"epoch_{ep}")
    74    550.3 MiB      0.0 MiB           1               os.makedirs(epoch_path, exist_ok=True)
    75    550.3 MiB      0.0 MiB           1               print(f"\t\t{participant}Starting epoch {ep}")
    76    550.3 MiB      0.0 MiB           1               history = []
    77   6433.6 MiB -13257.3 MiB          11               for k in range(ProjectConfig().n_splits):
    78   6523.9 MiB  -9093.3 MiB          10                   train_fold = pd.read_csv(train_folds[k], sep=";")
    79   6525.7 MiB -10294.1 MiB          10                   validation_fold = pd.read_csv(validation_folds[k], sep=";")
    80   6433.6 MiB  -8830.9 MiB          10                   hist = self.train_fold(model, train_fold, validation_fold, participant, ep, k)
    81   6433.6 MiB -13358.4 MiB          10                   history.append(hist)
    82                                         
    83   6433.6 MiB -13358.3 MiB          10                   current, peak = tracemalloc.get_traced_memory()
    84   6433.6 MiB -13358.2 MiB          10                   print(f"\t\t\t{participant}Memory usage after fold {k}: Current = {current / 10 ** 6} MB; Peak = {peak / 10 ** 6} MB")
    85                                         
    86   6433.6 MiB -13168.4 MiB          10                   del train_fold
    87   6433.6 MiB -13156.4 MiB          10                   del validation_fold
    88   6433.6 MiB -13257.3 MiB          10                   gc.collect()
    89   5373.1 MiB  -1060.5 MiB           1               final_history = {}
    90   5373.1 MiB      0.0 MiB           5               for key in history[0].history.keys():
    91   5373.1 MiB      0.1 MiB          52                   final_history.update({key: np.concatenate([hist.history[key] for hist in history])})
    92                                         
    93   5373.1 MiB      0.0 MiB           1               current, peak = tracemalloc.get_traced_memory()
    94   5373.1 MiB      0.0 MiB           1               print(f"\t\t{participant}Memory usage after epoch {ep}: Current = {current / 10 ** 6} MB; Peak = {peak / 10 ** 6} MB")
    95   5373.1 MiB      0.0 MiB           1               history_path = os.path.join(epoch_path, f"history_{participant}_{ep}_all_folds.csv")
    96   5375.3 MiB      2.2 MiB           1               pd.DataFrame.from_dict(final_history).to_csv(history_path, sep=";")
    97   5375.3 MiB      0.0 MiB           1               path_list.append(history_path)
    98   5375.3 MiB      0.0 MiB           1               print(f"\t\t{participant}Finished epoch {ep}")
    99   5375.3 MiB      0.0 MiB           1               print(f"\t\t{participant}Saving lstm ({ProjectConfig().neural_network}) weights")
   100   5375.3 MiB      0.0 MiB           1               weights_path = os.path.join(epoch_path, f"weights_{patient_idx}_{ep}.weights.h5")
   101   5377.8 MiB      2.4 MiB           1               model.save_weights(weights_path)
   102   5377.8 MiB      0.0 MiB           1               path_list.append(weights_path)
   103   5377.8 MiB      0.0 MiB           1           print(f"\t{participant}Finished participant {patient_idx}")
   104   5377.8 MiB      0.0 MiB           1           print(f"\t{participant}Saving lstm ({ProjectConfig().neural_network}) model")
   105   5377.8 MiB      0.0 MiB           1           model_path = os.path.join(self.results_path, participant, f"model_{patient_idx}.keras")
   106   5378.9 MiB      1.1 MiB           1           model.save(model_path)
   107   5378.9 MiB      0.0 MiB           1           path_list.append(model_path)
   108                                         
   109   5378.9 MiB      0.0 MiB           1           test_path = os.path.join(self.results_path, participant, f"test_{patient_idx}.csv")
   110   5378.9 MiB      0.0 MiB           1           shutil.copyfile(test, test_path)
   111   5379.0 MiB      0.1 MiB           1           path_list.append(test_path)
   112                                         
   113   5379.0 MiB      0.0 MiB           1           current, peak = tracemalloc.get_traced_memory()
   114   5379.0 MiB      0.0 MiB           1           print(f"\t{participant}Memory usage after all epochs: Current = {current / 10 ** 6} MB; Peak = {peak / 10 ** 6} MB")
   115                                         
   116   5379.0 MiB      0.0 MiB           1           del model
   117   5379.0 MiB      0.0 MiB           1           gc.collect()
   118   5473.8 MiB     94.9 MiB           1           clear_session()
   119   5473.8 MiB      0.0 MiB           1           current, peak = tracemalloc.get_traced_memory()
   120   5473.8 MiB      0.0 MiB           1           print(f"\tMemory usage after patient {participant}: Current = {current / 10 ** 6} MB; Peak = {peak / 10 ** 6} MB")
   121                                         
   122   5680.4 MiB    206.5 MiB           1           tracemalloc.stop()
   123   5680.4 MiB      0.0 MiB           1           return path_list

Hope someone knows how to fix this issue. Thanks!

What I've tried

I've tried reading the folds only when needed, explicitly freeing memory by deleting variables and calling the garbage collector, and using different parallelization techniques. Every time, a single participant still consumes more memory than the machine can handle.

dryglicki commented 1 week ago

I have nothing to provide you but solidarity. I am running into this same problem with a TFRecords data pipeline:

def _parse_function(example_proto):
    """Decode one serialized record into a ``(priors, forecasts)`` tensor pair.

    The record carries six int64 dimension fields (missing ones default to 0)
    and two serialized float32 tensors (missing ones default to ''). The
    tensors are parsed and reshaped to ``[ntp, ny, nx, ncp]`` (priors) and
    ``[ntf, ny, nx, ncf]`` (forecasts).
    """
    dimension_keys = ('ny', 'nx', 'ntp', 'ntf', 'ncp', 'ncf')
    spec = {key: tf.io.FixedLenFeature([], tf.int64, default_value=0)
            for key in dimension_keys}
    for blob_key in ('priors', 'forecasts'):
        spec[blob_key] = tf.io.FixedLenFeature([], tf.string, default_value='')

    parsed = tf.io.parse_example(example_proto, spec)

    # Spatial dimensions shared by both tensors.
    height, width = parsed['ny'], parsed['nx']

    prior_shape = [parsed['ntp'], height, width, parsed['ncp']]
    forecast_shape = [parsed['ntf'], height, width, parsed['ncf']]

    priors = tf.reshape(tf.io.parse_tensor(parsed['priors'], tf.float32),
                        shape=prior_shape)
    forecasts = tf.reshape(tf.io.parse_tensor(parsed['forecasts'], tf.float32),
                           shape=forecast_shape)
    return priors, forecasts

...
def create_dataset_onr_tfrecords(path,
                                 glob,
                                 batch_size=32,
                                 compression='GZIP',
                                 shuffle=True,
                                 deterministic=False):
    """Build a batched, prefetched dataset from the TFRecord files matching ``path / glob``.

    Files are listed (optionally shuffled) and read in parallel with
    ``interleave``. Each record is decoded with ``_parse_function``, records
    are grouped into full batches of ``batch_size`` (the incomplete remainder
    is dropped), and batches are prefetched.
    """
    def open_records(filename):
        # One TFRecordDataset per file, using the requested compression.
        return tf.data.TFRecordDataset(filename, compression_type=compression)

    filenames = tf.data.Dataset.list_files(str(path / glob), shuffle=shuffle)
    records = filenames.interleave(open_records,
                                   cycle_length=tf.data.AUTOTUNE,
                                   num_parallel_calls=tf.data.AUTOTUNE,
                                   deterministic=deterministic)
    decoded = records.map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)
    batched = decoded.batch(batch_size, drop_remainder=True)
    return batched.prefetch(tf.data.AUTOTUNE)

I'll spare you the plot, but I am having the same issue with a vanilla TF dataset. I've tried removing interleave, removing GZIP compression, calling TFRecordDataset directly, removed batching, removed prefetching... nothing.

I believe this is a Tensorflow problem and (in particular) a TF Dataset problem: https://github.com/tensorflow/tensorflow/issues/65675

This TF 2.16 + K3 era has been a disaster. Not the Keras part -- just some growing pains. But TF, man...

CZtheHusky commented 4 days ago

I am facing the same problem, using scripts from here: https://github.com/kpertsch/rlds_dataset_mod

which also relies on certain TensorFlow Datasets features. The scripts are intended to modify an existing TensorFlow dataset stored in TFRecord format.