tensorflow / tensor2tensor

Library of deep learning models and datasets designed to make deep learning more accessible and accelerate ML research.
Apache License 2.0

Bair Robot Pushing SV2P Evaluation Metrics #1767

Open tungalbert99 opened 4 years ago

tungalbert99 commented 4 years ago

Description

When running the SV2P algorithm on the BAIR Robot Pushing dataset, the model crashes during evaluation whenever the list of evaluation metrics is not empty. After saving the model, it attempts to run evaluation, which should include PSNR and SSIM, but instead fails with a tensor mismatch error.

...

Environment information

OS: Ubuntu 18.04

For bugs: reproduction and error logs

Steps to reproduce:

1) Generate BAIR robot pushing dataset
2) Run SV2P

Error logs:

(error log attached as an image; not available in text form)

sunhaoyuan3310 commented 2 years ago

I ran into the same problem and solved it by replacing the function reduce_dimensions in tensor2tensor/utils/metrics.py.

Original function (assuming you are using t2t 1.13.0):

def reduce_dimensions(predictions, labels):
    """Reduce dimensions for high-dimensional predictions and labels."""
    # We will treat first dimensions as batch. One example are video frames.
    if len(predictions.get_shape()) > 5:
      predictions_shape = common_layers.shape_list(predictions)
      predictions = tf.reshape(
          predictions, [predictions_shape[0], predictions_shape[1], -1,
                        predictions_shape[-1]])
      labels_shape = common_layers.shape_list(labels)
      labels = tf.reshape(
          labels, [labels_shape[0], labels_shape[1], -1])
    return predictions, labels

Replacement function (copied from t2t 1.7.0):

def reduce_dimensions(predictions, labels):
    """Reduce dimensions for high-dimensional predictions and labels."""
    # We will treat first dimensions as batch. One example are video frames.
    if len(predictions.get_shape()) > 5:
      predictions = tf.reshape(
          predictions, [-1] + common_layers.shape_list(predictions)[-4:])
    if len(labels.get_shape()) > 4:
      labels = tf.reshape(
          labels, [-1] + common_layers.shape_list(labels)[-3:])
    return predictions, labels

With this change, evaluation worked.
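
To see how the two versions differ, here is a minimal shape-only sketch in plain Python. It only mimics the reshape arithmetic of the two functions quoted above; it is not the library code. The helper names (resolve, shapes_v1_13, shapes_v1_7) and the example shapes (batch 2, 3 frames, 8x8 RGB, 256 logits per channel) are assumptions made up for illustration, since the actual shapes from the failing run are not shown in this thread.

```python
from math import prod


def resolve(shape, target):
    """Fill in a single -1 in `target` the way a reshape would."""
    total = prod(shape)
    known = prod(d for d in target if d != -1)
    return [total // known if d == -1 else d for d in target]


def shapes_v1_13(pred_shape, label_shape):
    """Shape effect of the 1.13.0-style function quoted above."""
    if len(pred_shape) > 5:
        # Keeps the first two dims and flattens everything in between.
        pred_shape = resolve(
            pred_shape, [pred_shape[0], pred_shape[1], -1, pred_shape[-1]])
        label_shape = resolve(
            label_shape, [label_shape[0], label_shape[1], -1])
    return pred_shape, label_shape


def shapes_v1_7(pred_shape, label_shape):
    """Shape effect of the 1.7.0-style function quoted above."""
    if len(pred_shape) > 5:
        # Folds the leading dims into one and keeps the trailing dims.
        pred_shape = resolve(pred_shape, [-1] + pred_shape[-4:])
    if len(label_shape) > 4:
        label_shape = resolve(label_shape, [-1] + label_shape[-3:])
    return pred_shape, label_shape


# Hypothetical video batch: [batch, frames, height, width, channels, logits]
predictions = [2, 3, 8, 8, 3, 256]
labels = [2, 3, 8, 8, 3]

print(shapes_v1_13(predictions, labels))
print(shapes_v1_7(predictions, labels))
```

Under these assumed shapes, the 1.13.0 version flattens height, width and channels into a single axis, while the 1.7.0 version folds batch and frames together and keeps height, width and channels intact. tf.image.psnr and tf.image.ssim expect the last three dimensions to be [height, width, channels], so this may be why the older reshape avoids the mismatch, but I have not confirmed that this is the exact cause of the error in the original report.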