stared / livelossplot

Live training loss plot in Jupyter Notebook for Keras, PyTorch and others
https://p.migdal.pl/livelossplot
MIT License
1.29k stars 143 forks

ImageDataGenerator incompatible #140

Open novitae opened 1 year ago

novitae commented 1 year ago

🐛 Bug description

The real-time progress table doesn't show up when training on a dataset fed through ImageDataGenerator. Here is the code that uses ImageDataGenerator and doesn't work:

from keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(
    rescale=1.0/255.0,
    validation_split=0.2
)

train_gen = datagen.flow_from_dataframe(
    dataframe=df,
    directory=None,
    x_col="image_path",
    y_col=training_df.columns[1:],
    subset="training",
    batch_size=32,
    shuffle=True,
    class_mode="raw",
    target_size=(224, 224),
)

valid_gen = datagen.flow_from_dataframe(
    dataframe=df,
    directory=None,
    x_col="image_path",
    y_col=training_df.columns[1:],
    subset="validation",
    batch_size=32,
    shuffle=True,
    class_mode="raw",
    target_size=(224, 224),
)

from livelossplot import PlotLossesKeras
plotlosses = PlotLossesKeras()

history = model.fit(
    train_gen,
    validation_data=valid_gen,
    epochs=10,
    callbacks=[plotlosses],
    verbose=0
)
Screenshot 2023-08-03 at 13 20 02

And here is the working code, which passes X_train, Y_train, X_val and Y_val arrays directly:

from sklearn.model_selection import train_test_split
from keras.preprocessing.image import load_img, img_to_array
import numpy as np

def load_image(image_path, target_size=(224,224)):
    img = load_img(image_path, target_size=target_size)
    img = img_to_array(img)
    img = img / 255.0  # Normalize to [0,1]
    return img

subset_df = df.sample(n=10, random_state=42)

train_df, val_df = train_test_split(subset_df, test_size=0.2, random_state=42)

X_train = np.array([load_image(path) for path in train_df['image_path']])
Y_train = train_df[train_df.columns[1:]].values  # Assumes first column is image_path, rest are labels

X_val = np.array([load_image(path) for path in val_df['image_path']])
Y_val = val_df[val_df.columns[1:]].values

from livelossplot import PlotLossesKeras
plotlosses = PlotLossesKeras()

history = model.fit(
    X_train, Y_train,
    validation_data=(X_val, Y_val),
    epochs=10,
    callbacks=[plotlosses],
    verbose=0,
)
Screenshot 2023-08-03 at 13 15 40

(Don't worry about the odd stats; this was just a demo training run on 10 images.)
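
Since the array-based call does plot, one possible stopgap is to pull the batches out of the generator into NumPy arrays and train through that path. The following is only a minimal sketch, assuming the generator supports `len()` and integer indexing (as Keras `Sequence`-style iterators do) and that the whole dataset fits in memory. `materialize` is a hypothetical helper name, not part of livelossplot or Keras, and the list of fake batches merely stands in for `train_gen`.

```python
import numpy as np


def materialize(batches, max_batches=None):
    """Stack (x, y) batches from an indexable batch sequence into two arrays.

    Hypothetical helper: `batches` is assumed to support len() and
    integer indexing, each item being a tuple starting with (x, y).
    """
    n = len(batches) if max_batches is None else min(len(batches), max_batches)
    xs, ys = [], []
    for i in range(n):
        x, y = batches[i][:2]  # ignore optional sample weights, if any
        xs.append(np.asarray(x))
        ys.append(np.asarray(y))
    # Join along the batch axis so the result looks like X_train / Y_train.
    return np.concatenate(xs), np.concatenate(ys)


# Stand-in for a generator: 3 batches of 4 "images" (8x8x3) with 2 labels each.
fake_gen = [(np.zeros((4, 8, 8, 3)), np.ones((4, 2))) for _ in range(3)]
X, Y = materialize(fake_gen)
print(X.shape, Y.shape)  # (12, 8, 8, 3) (12, 2)
```

With the real generators this would presumably be `X_train, Y_train = materialize(train_gen)` and `X_val, Y_val = materialize(valid_gen)`, followed by the array-based `model.fit(...)` call above. It sidesteps the missing plot rather than explaining it.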

Environment