keras-team / keras

Deep Learning for humans
http://keras.io/

Accessing dataloader and model output in callbacks #19022

Open innat opened 10 months ago

innat commented 10 months ago

Reopening from here.

High-level overview:

def on_test_batch_begin(self, batch, batch_id, prediction, logs=None):
    # init or reinit
    batch_id: int
    batch: dict
    prediction: dict

def on_test_batch_end(self, batch, batch_id, prediction, logs=None):
    batch_id: int
    batch: dict  # from the dataloader
        'image_array': (batch_size, height, width, 3)
        'mask_array': (batch_size, height, width, 3)
        'target_array': (batch_size, 3)
        'bounding_box_array': (batch_size, num_detections, 5)
        ...
    prediction: dict  # from the model's test_step
        'pred_mask_array': (batch_size, height, width, 3)
        'pred_target_array': (batch_size, 3)
        'pred_bounding_box_array': (batch_size, num_detections, 5)
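
To make the request concrete, a user-side callback consuming these hooks might look like the sketch below. This is hypothetical: the hook signature is the proposed one (not current Keras), and save_visualization is an illustrative helper.

import keras

class PerBatchVizCallback(keras.callbacks.Callback):
    # Hypothetical: relies on the proposed hook signature above.
    def on_test_batch_end(self, batch, batch_id, prediction, logs=None):
        # Everything needed for per-sample work arrives with the batch.
        save_visualization(
            batch["image_array"],
            batch["mask_array"],
            prediction["pred_mask_array"],
        )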
grasskin commented 10 months ago

Hi @innat, would a custom metric work for this use case? It should have access to model outputs, as opposed to custom callbacks (which do not). https://keras.io/api/metrics/#creating-custom-metrics
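
For example, a minimal sketch of such a metric (the confusion-matrix bookkeeping is just an illustration; the point is that update_state receives the targets and the model outputs for every batch):

import tensorflow as tf
import keras

class ConfusionMatrixMetric(keras.metrics.Metric):
    """Accumulates a confusion matrix across batches; illustrative only."""

    def __init__(self, num_classes, name="confusion_matrix", **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes
        self.cm = self.add_weight(
            name="cm", shape=(num_classes, num_classes), initializer="zeros"
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Called once per batch with the targets and the model outputs.
        preds = tf.argmax(y_pred, axis=-1)
        batch_cm = tf.math.confusion_matrix(
            tf.reshape(tf.cast(y_true, tf.int64), [-1]),
            tf.reshape(preds, [-1]),
            num_classes=self.num_classes,
            dtype=tf.float32,
        )
        self.cm.assign_add(batch_cm)

    def result(self):
        # Keras expects a scalar here; the full matrix stays in self.cm.
        correct = tf.reduce_sum(tf.linalg.diag_part(self.cm))
        return tf.math.divide_no_nan(correct, tf.reduce_sum(self.cm))

    def reset_state(self):
        self.cm.assign(tf.zeros_like(self.cm))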

innat commented 10 months ago

Hi @grasskin, thanks for checking.

would a custom metric work for this use case?

I'm afraid not, and it's not exactly clear to me how custom metrics relate to this scenario. I've added some arguments for the use case of the requested feature. If possible, could you please elaborate with some pseudocode?

SuryanarayanaY commented 9 months ago

Hi @innat ,

Does PR #19041 have any bearing on this issue?

innat commented 9 months ago

@SuryanarayanaY Thanks for checking. That PR is not relevant to this issue.


Let's say I have a dataloader (say, built with the tf.data API), and I can do model.predict(tf_data) to get all the predictions at once.

But at the same time, I also need each sample (x) for individual processing (e.g. for visualization), or to evaluate the trained model against the targets (y) with additional metrics that were not part of the compile method (e.g. a confusion matrix, a classification report, pycocotools, etc.). For that I need to unpack the tf_data dataset to get x and y, which often overflows system memory if the dataset is big enough. Unpacking a tf.data dataset back into plain NumPy arrays is one of the most-asked questions (~85K views) on Stack Overflow.

Currently we can do:

model.fit(train_ds, validation_data=validation_ds)

# Unpack the validation set back into Python lists (memory-hungry).
gt_samples = [samples for samples, _ in validation_ds.unbatch()]
gt_labels = [labels for _, labels in validation_ds.unbatch()]

# A second full pass over data the model already saw during fit().
pred_labels = model.predict(validation_ds)

scores = any_metrics(gt_labels, pred_labels)
save_visualization(
    gt_samples,
    gt_labels,
    pred_labels,
    scores,
)

Also notice that the model.predict call on the validation set should be redundant here: the data is already processed during model.fit. That is, we already have pred_labels inside the test_step method, so it should not be necessary to compute them again via predict_step.

To overcome this, we can subclass the model and override the test_step method to store gt_samples, gt_labels, and pred_labels in some placeholders, as shown below. This can be costly or complicated for many users to implement for advanced modelling, e.g. object detection.

import tensorflow as tf
import keras

# Growing placeholders; shapes elided. In practice these need an
# unspecified leading dimension so that assign() can change their size.
val_x = tf.Variable(...)
val_y = tf.Variable(...)
val_pred = tf.Variable(...)

class ExtendedModel(keras.Model):
    def __init__(self, model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The actual model being wrapped.
        self.model = model

    def test_step(self, data):
        x, y = data
        y_pred = self.model(x, training=False)
        # tf.keras-style loss/metric bookkeeping.
        self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        self.compiled_metrics.update_state(y, y_pred)

        # Extra: accumulate inputs, targets, and predictions <---------
        val_x.assign(tf.concat([val_x, x], axis=0))
        val_y.assign(tf.concat([val_y, y], axis=0))
        val_pred.assign(tf.concat([val_pred, y_pred], axis=0))
        return {m.name: m.result() for m in self.metrics}
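
Usage would then look roughly like this (a sketch; base_model, the compile arguments, and any_metrics are assumed to exist already):

# Wrap the trained model and run evaluation once.
wrapped = ExtendedModel(base_model)
wrapped.compile(loss="sparse_categorical_crossentropy", metrics=["accuracy"])  # illustrative
wrapped.evaluate(validation_ds)

# After evaluate(), the placeholders hold everything test_step accumulated.
scores = any_metrics(val_y.numpy(), val_pred.numpy())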

Instead, as mentioned here, I was suggesting taking a look at this issue: making these accessible from the callback API would be much cleaner, IMHO.
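
Under the proposed API, the whole ExtendedModel / tf.Variable machinery above would collapse into something like the sketch below (hypothetical: it relies on the proposed hook signature, which does not exist in Keras today):

import keras

class StreamingEvalCallback(keras.callbacks.Callback):
    # Hypothetical: uses the proposed on_test_batch_end signature.
    def __init__(self):
        super().__init__()
        self.gt, self.pred = [], []

    def on_test_batch_end(self, batch, batch_id, prediction, logs=None):
        # One batch at a time: no unbatch() pass, no second predict().
        self.gt.append(batch["target_array"])
        self.pred.append(prediction["pred_target_array"])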

Let me know what you think. It would be a somewhat big change to the API, so if the keras-team doesn't want to include it, feel free to close the issue. Thank you.