alexlee-gk / video_prediction

Stochastic Adversarial Video Prediction
https://alexlee-gk.github.io/video_prediction/
MIT License

Testing on custom images #45

Open BonJovi1 opened 4 years ago

BonJovi1 commented 4 years ago

Hi Alex, thanks for releasing the code! I have a small question. I'm trying to test one of the pre-trained models on a custom dataset. Using kth_dataset.py as a template, I created the .pkl file for my data, resized all my images to 64x64, and converted everything to .tfrecords. My test set now looks like this:

test/ 
     sequence_0_to_9.tfrecords
     sequence_lengths.txt

The dataset is very small: just 10 sequences, each with a sequence length of 10.
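A quick sanity check is to count how many sequences are long enough for the requested `sequence_length`. The sketch below assumes that sequence_lengths.txt holds one integer per line, one line per sequence. That format and the helper name `count_usable_sequences` are assumptions for illustration, not taken from the repository.

```python
def count_usable_sequences(lengths_text, sequence_length):
    """Hypothetical helper: count sequences with at least `sequence_length` frames.

    Assumes `lengths_text` is the content of sequence_lengths.txt,
    with one integer per line and one line per sequence.
    """
    lengths = [int(line) for line in lengths_text.splitlines() if line.strip()]
    return sum(1 for n in lengths if n >= sequence_length)


# Ten sequences of length 10, as described above.
contents = "\n".join(["10"] * 10)
print(count_usable_sequences(contents, 2))   # every sequence fits
print(count_usable_sequences(contents, 11))  # none fit
```

In practice you would pass the text read from the file with `open(...).read()`.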

I'm trying to use the ours_savp pre-trained model that you provided for the KTH dataset. It works on KTH, but it fails on my custom dataset. This is the command I'm running:

python scripts/generate.py --input_dir data/habitat --dataset_hparams sequence_length=2 --checkpoint pretrained_models/kth/ours_savp/ --mode test --results_dir results_test_samples/habitat --batch_size 1

It fails with this error:

Traceback (most recent call last):
  File "scripts/generate.py", line 193, in <module>
    main()
  File "scripts/generate.py", line 135, in main
    model.build_graph(input_phs)
  File "/scratch/abhinav/video_prediction/video_prediction/models/base_model.py", line 478, in build_graph
    outputs_tuple, losses_tuple, loss_tuple, metrics_tuple = self.tower_fn(self.inputs)
  File "/scratch/abhinav/video_prediction/video_prediction/models/base_model.py", line 412, in tower_fn
    gen_outputs = self.generator_fn(inputs)
  File "/scratch/abhinav/video_prediction/video_prediction/models/savp_model.py", line 730, in generator_fn
    gen_outputs_posterior = generator_given_z_fn(inputs_posterior, mode, hparams)
  File "/scratch/abhinav/video_prediction/video_prediction/models/savp_model.py", line 693, in generator_given_z_fn
    cell = SAVPCell(inputs, mode, hparams)
  File "/scratch/abhinav/video_prediction/video_prediction/models/savp_model.py", line 311, in __init__
    ground_truth_sampling = tf.constant(False, dtype=tf.bool, shape=ground_truth_sampling_shape)
  File "/home/luke.skywalker/anaconda3/envs/savp2/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py", line 196, in constant
    value, dtype=dtype, shape=shape, verify_shape=verify_shape))
  File "/home/luke.skywalker/anaconda3/envs/savp2/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py", line 491, in make_tensor_proto
    (shape_size, nparray.size))
ValueError: Too many elements provided. Needed at most -9, but received 1

I think it's because I'm not setting the batch_size and sequence_length parameters properly. When I increase the sequence_length from 2 to 3, I get:

ValueError: Too many elements provided. Needed at most -8, but received 1
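The two numbers differ by exactly the change in sequence_length, which suggests the first dimension of `ground_truth_sampling_shape` is computed from sequence_length minus a fixed offset. A minimal sketch of that arithmetic, assuming (not shown in the post) that the shape is `[sequence_length - 1 - context_frames, batch_size]` and that context_frames is 10 for the KTH hparams. Under those assumptions it reproduces both reported values. `check_constant_size` is a hypothetical stand-in for the element-count comparison `tf.constant` performs when filling a single scalar into `shape`; it is not the TensorFlow code itself.

```python
import numpy as np


def ground_truth_sampling_shape(sequence_length, context_frames, batch_size):
    # Assumed formula: one entry per predicted step, per batch element.
    return [sequence_length - 1 - context_frames, batch_size]


def check_constant_size(shape, num_values=1):
    # Hypothetical mimic of the "Too many elements" check for a scalar value.
    # (TensorFlow special-cases a zero-sized shape; that case is ignored here.)
    shape_size = int(np.prod(shape, dtype=np.int64))
    if num_values > shape_size:
        raise ValueError("Too many elements provided. Needed at most %d, but received %d"
                         % (shape_size, num_values))
    return shape_size


for seq_len in (2, 3):
    shape = ground_truth_sampling_shape(seq_len, context_frames=10, batch_size=1)
    try:
        check_constant_size(shape)
    except ValueError as err:
        print(seq_len, err)
```

Under these assumptions, the first dimension stays non-positive until sequence_length exceeds context_frames + 1.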

I suspect I may need a larger dataset, but is it possible to make this work with the current one? Could you advise me on how to fix this?

Thank you, Abhinav