coxlab / prednet

Code and models accompanying "Deep Predictive Coding Networks for Video Prediction and Unsupervised Learning"
https://arxiv.org/abs/1605.08104
MIT License
759 stars, 259 forks

How can I use the t+5 future frame prediction without ground truth? #60

Closed ronykalfarisi closed 4 years ago

ronykalfarisi commented 5 years ago

Hi @bill-lotter, thank you for creating and maintaining this awesome repository. I have a question regarding t+5 future frame prediction. Let's say I have 20 images in a folder and I have successfully created the hickle files for the data and the sources. I want to predict the next 5 frames from these 20 images using the fine-tuned model weights. How can I achieve this? I used nt=25 and extrap_start_time=20 but I got an error. Thanks

bill-lotter commented 5 years ago

Hi, what did the error say? Did you try doing something similar to what's in kitti_extrap_finetune.py?

ronykalfarisi commented 5 years ago

@bill-lotter thank you so much for your quick reply. I have two questions regarding this issue.

  1. Yes, I followed kitti_extrap_finetune.py and slightly modified the code. The following is the code that I used, which is very similar to kitti_evaluate.py:

    # CONSTANT VALUES
    n_plot = 40
    batch_size = 10
    nt = 25
    extrap_start_time = 20

    # Setting up path file
    weights_file = os.path.join(WEIGHTS_DIR, 'tensorflow_weights/prednet_kitti_weights-extrapfinetuned.hdf5')
    json_file = os.path.join(WEIGHTS_DIR, 'prednet_kitti_model-extrapfinetuned.json')

    mini_file = os.path.join(DATA_DIR, 'X_test_mini.hkl')  # <--- shape is (20, 128, 160, 3)
    mini_sources = os.path.join(DATA_DIR, 'sources_test_mini.hkl')

    # Load trained model
    with open(json_file, 'r') as f:
        json_string = f.read()
    train_model = model_from_json(json_string, custom_objects = {'PredNet': PredNet})
    train_model.load_weights(weights_file)

    # Create testing model (to output predictions)
    layer_config = train_model.layers[1].get_config()
    layer_config['output_mode'] = 'prediction'
    layer_config['extrap_start_time'] = extrap_start_time
    data_format = layer_config['data_format'] if 'data_format' in layer_config else layer_config['dim_ordering']

    test_prednet = PredNet(weights=train_model.layers[1].get_weights(), **layer_config)

    input_shape = list(train_model.layers[0].batch_input_shape[1:])
    input_shape[0] = nt
    inputs = Input(shape=tuple(input_shape))
    predictions = test_prednet(inputs)
    test_model = Model(inputs=inputs, outputs=predictions)

    test_generator = SequenceGenerator(mini_file, mini_sources, nt, sequence_start_mode='unique', data_format=data_format)
    X_test = test_generator.create_all()  # <--- (0, 25, 128, 160, 3)
    X_hat = test_model.predict(X_test, batch_size)  # <--- empty or []
    if data_format == 'channels_first':
        X_test = np.transpose(X_test, (0, 1, 3, 4, 2))
        X_hat = np.transpose(X_hat, (0, 1, 3, 4, 2))


I found what caused the error: **X_hat is empty**. This happens because X_test (from SequenceGenerator) has shape (0, 25, 128, 160, 3), which I believe is due to nt=25 while there are only 20 images. If I use nt=20 (with extrap_start_time still 20), X_test has shape (1, 20, 128, 160, 3), the code works as intended, and X_hat has shape (1, 20, 128, 160, 3) as well. However, I still don't know how to get the next 5 frame predictions using this code, which brings me to my second question.

  2. Looking back, it turns out I still don't know how to predict the t+1 frame without ground truth. If I have only 10 frames in a directory and convert them into a hickle file, I will have X_test with shape (1, 10, im_row, im_col, chan) after creating the SequenceGenerator. If I pass this into model.predict(), I will get X_hat with the same shape. From my understanding, the first frame in X_hat is all zero, the second frame is the prediction of the first frame from X_test, the third frame of X_hat is the prediction of the second frame of X_test, and so on. Therefore, the 10th frame of X_hat is the prediction of the 9th frame of X_test, and since nt=10, it stops there. My question is: how can I predict the 11th frame using X_test (where I don't really care about any previous frame predictions)?
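
The empty X_test described in the first question can be illustrated with a small sketch. It assumes that sequence_start_mode='unique' cuts the frames into non-overlapping windows of length nt and drops any incomplete window, and that all frames come from a single source. The helper `unique_sequence_starts` is a hypothetical stand-in written for this illustration, not a function from the repository.

```python
def unique_sequence_starts(n_frames, nt):
    """Start indices of non-overlapping length-nt windows over n_frames frames.

    Hypothetical stand-in for the 'unique' start mode, assuming one source
    and that incomplete trailing windows are discarded.
    """
    starts = []
    loc = 0
    while loc + nt <= n_frames:
        starts.append(loc)
        loc += nt
    return starts


for n_frames, nt in [(20, 25), (20, 20), (25, 25)]:
    n_seq = len(unique_sequence_starts(n_frames, nt))
    print(f"{n_frames} frames, nt={nt}: {n_seq} sequence(s)")
```

Under these assumptions, 20 frames with nt=25 give no sequence at all and 20 frames with nt=20 give exactly one, which is consistent with the (0, 25, ...) and (1, 20, ...) shapes reported above. Getting a sequence of length 25 would require 25 frames in the data file.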

Thank you so much for your help

WalkInThePast commented 5 years ago

Hi, I also encountered the same problem. Have you solved it?

ronykalfarisi commented 5 years ago

@WalkInThePast, yeah. I just appended 5 empty (all-zero) frames to X_test during data processing, then fed the padded array into the predict() method to get X_hat.
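
A minimal sketch of this padding idea, using only NumPy. The helper name `pad_with_future_frames` and the random stand-in data are assumptions made for illustration. The sketch also assumes that, with extrap_start_time set to the number of real frames, the model ignores the placeholder frames and feeds its own predictions back in from that step onward.

```python
import numpy as np


def pad_with_future_frames(frames, n_future):
    """Append n_future all-zero placeholder frames to a (T, H, W, C) array."""
    pad = np.zeros((n_future,) + frames.shape[1:], dtype=frames.dtype)
    return np.concatenate([frames, pad], axis=0)


# Stand-in for 10 real frames of size 128x160 (random data, illustration only).
observed = np.random.randint(0, 256, size=(10, 128, 160, 3), dtype=np.uint8)
padded = pad_with_future_frames(observed, n_future=5)

nt = padded.shape[0]                   # 15: total sequence length given to the model
extrap_start_time = observed.shape[0]  # 10: first step that has no ground truth
print(padded.shape, nt, extrap_start_time)
```

With a model built for this nt and extrap_start_time, the extrapolated frames would then be the slice X_hat[:, extrap_start_time:], while the zero placeholders in X_test at those positions carry no information.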

Roy-Rupak commented 4 years ago

@ronykalfarisi While trying to run kitti_extrap_finetune.py, I got the error "ValueError: Error when checking input: expected input_1 to have shape (10, 3, 128, 160) but got array with shape (10, 128, 160, 3)". Did you face any such issue? Also, can you share the code that you used for training the multi-step prediction? Thanks.
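
The shapes in this error message suggest a mismatch between a model that expects channels-first input and data stored channels-last. As a hedged sketch of one possible workaround (not a confirmed fix for this case), the array axes can be reordered with NumPy before calling the model. The helper names below are hypothetical.

```python
import numpy as np


def to_channels_first(batch):
    """Reorder (batch, time, H, W, C) into (batch, time, C, H, W)."""
    return np.transpose(batch, (0, 1, 4, 2, 3))


def to_channels_last(batch):
    """Reorder (batch, time, C, H, W) into (batch, time, H, W, C)."""
    return np.transpose(batch, (0, 1, 3, 4, 2))


x = np.zeros((2, 10, 128, 160, 3), dtype=np.float32)
x_cf = to_channels_first(x)
print(x_cf.shape)
```

The second helper uses the same axis order, (0, 1, 3, 4, 2), that the evaluation code in this thread applies before plotting. Another option may be to make the Keras image data format setting match the format the model was saved with, so that no manual reordering is needed.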

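ronykalfarisi commented 4 years ago

import os

import hickle
import numpy as np
import PIL.Image
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from keras.models import Model, model_from_json
from keras.layers import Input

# Modules from the prednet repository
from prednet import PredNet
from data_utils import SequenceGenerator

def process_data(settings, future_frame=5):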
    tmpdir = 'tmp/data-ee'
    desired_im_sz = (settings.HEIGHT, settings.WIDTH)

    print("Preparing to process test data...")
    next_recordings = [('all', tmpdir)]
    splits = {s: [] for s in ['single']}
    splits['single'] = next_recordings
    for split in splits:
        im_list = []
        # corresponds to recording that image came from
        source_list = []
        for category, folder in splits[split]:
            _, _, files = next(os.walk(folder))
            im_list += [os.path.join(folder, f) for f in sorted(files)]
            source_list += [category + '-' + folder] * (len(files) + future_frame)

        print(f'Creating {split} data: {len(im_list)} images')
        X = np.zeros((len(im_list) + future_frame,) + desired_im_sz + (3,), np.uint8)
        for i, im_file in enumerate(im_list):
            # im = imread(im_file, mode='RGB')
            # X[i] = process_im(im, desired_im_sz)
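            # PIL's resize() expects (width, height), while desired_im_sz is (height, width)
            im = PIL.Image.open(im_file).convert('RGB')
            im = im.resize(desired_im_sz[::-1], resample=PIL.Image.BICUBIC)
            X[i] = np.asarray(im)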

        hickle.dump(X, os.path.join(settings.PROC_DIR, 'X_' + split + '.hkl'))
        hickle.dump(source_list, os.path.join(settings.PROC_DIR, 'sources_' + split + '.hkl'))
    print("Processing single test data finished")

def run(settings, extrapolation=True):
    process_data(settings)

    # CONSTANT VALUES
    n_plot = 40
    batch_size = 10
    nt = 11
    extrap_start_time = None

    # Setting up path file
    weights_file = os.path.join(settings.WEIGHTS_DIR, 'prednet_ee_weights.hdf5')
    json_file = os.path.join(settings.WEIGHTS_DIR, 'prednet_ee_model.json')

    # test_file = os.path.join(PROC_DIR, 'X_test.hkl')
    # test_sources = os.path.join(PROC_DIR, 'sources_test.hkl')
    # mini_file = os.path.join(PROC_DIR, 'X_test_mini.hkl')
    # mini_sources = os.path.join(PROC_DIR, 'sources_test_mini.hkl')
    # single_file = os.path.join(PROC_DIR, 'X_single.hkl')
    # single_sources = os.path.join(PROC_DIR, 'sources_single.hkl')

    # Data files written by process_data(); defined up front so they exist for both settings
    single_file = os.path.join(settings.PROC_DIR, 'X_single.hkl')
    single_sources = os.path.join(settings.PROC_DIR, 'sources_single.hkl')

    if extrapolation:
        # 10 real frames followed by 5 zero-padded frames to be extrapolated
        nt = 15
        extrap_start_time = 10
        weights_file = os.path.join(settings.WEIGHTS_DIR, 'prednet_ee_weights-extrapfinetuned.hdf5')
        json_file = os.path.join(settings.WEIGHTS_DIR, 'prednet_ee_model-extrapfinetuned.json')

    # Load trained model
    with open(json_file, 'r') as f:
        json_string = f.read()
    train_model = model_from_json(json_string, custom_objects = {'PredNet': PredNet})
    train_model.load_weights(weights_file)

    # Create testing model (to output predictions)
    layer_config = train_model.layers[1].get_config()
    layer_config['output_mode'] = 'prediction'
    layer_config['extrap_start_time'] = extrap_start_time
    data_format = layer_config['data_format'] if 'data_format' in layer_config else layer_config['dim_ordering']
    test_prednet = PredNet(weights=train_model.layers[1].get_weights(), **layer_config)

    input_shape = list(train_model.layers[0].batch_input_shape[1:])
    input_shape[0] = nt
    inputs = Input(shape=tuple(input_shape))
    predictions = test_prednet(inputs)
    test_model = Model(inputs=inputs, outputs=predictions)

    # test_generator = SequenceGenerator(test_file, test_sources, nt, sequence_start_mode='unique', data_format=data_format)
    # test_generator = SequenceGenerator(mini_file, mini_sources, nt, sequence_start_mode='unique', data_format=data_format)
    test_generator = SequenceGenerator(single_file, single_sources, nt, sequence_start_mode='unique', data_format=data_format)
    X_test = test_generator.create_all()
    X_hat = test_model.predict(X_test, batch_size)
    if data_format == 'channels_first':
        X_test = np.transpose(X_test, (0, 1, 3, 4, 2))
        X_hat = np.transpose(X_hat, (0, 1, 3, 4, 2))

    # Plot some predictions
    aspect_ratio = float(X_hat.shape[2]) / X_hat.shape[3]
    # plt.figure(figsize = (nt, 2*aspect_ratio))
    nt_plot = 5
    # pos = int(nt/nt_plot)

    plot_save_dir = os.path.join(settings.RESULTS_SAVE_DIR, 'prediction_plots/')
    if not os.path.exists(plot_save_dir): 
        os.makedirs(plot_save_dir)

    plot_idx = np.random.permutation(X_test.shape[0])[:n_plot]
    for i in plot_idx:
        fig = plt.figure(figsize=(4*nt, 8*aspect_ratio))
        gs = gridspec.GridSpec(2, nt, figure=fig)
        gs.update(wspace=0., hspace=0.)
        for t in range(nt_plot):
            # plt.subplot(gs[t])
            # print(f'X_test[{i,t}] shape : {X_test[i,t].shape} ')
            # plt.imshow(X_test[i, pos * (t + 1)], interpolation='none')
            fig.add_subplot(gs[0, t])
            plt.imshow(X_test[i, t + 10], interpolation='none')
            plt.tick_params(axis='both', which='both', bottom=False, top=False, 
                            left=False, right=False, labelbottom=False, labelleft=False)
            if t==0: 
                plt.ylabel('Actual', fontsize=10)

            # plt.subplot(gs[t + nt])
            # print(f'X_hat[{i,t}] shape : {X_hat[i,t].shape} ')
            fig.add_subplot(gs[1, t])
            plt.imshow(X_hat[i, t + 10], interpolation='none')
            img = X_hat[i, t + 10]*255

            # cv2.imshow('plot_' + str(t), img)
            # cv2.waitKey(0)
            # cv2.destroyAllWindows()
            # plt.imshow(X_hat[i, pos * (t + 1)], interpolation='none')
            plt.tick_params(axis='both', which='both', bottom=False, top=False, 
                            left=False, right=False, labelbottom=False, labelleft=False)
            if t==0: 
                plt.ylabel('Predicted', fontsize=10)

            # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # cv2.imwrite('plot_ee_' + str(t) + '.png', img)
        # plt.show()
        plt.savefig(os.path.join(plot_save_dir, 'plot_' + str(i) + '.png'))
        plt.close(fig)  # release the figure so memory does not grow across iterations

Good luck...

Roy-Rupak commented 4 years ago

Thanks @ronykalfarisi