Araachie / river

Efficient Video Prediction via Sparsely Conditioned Flow Matching. In ICCV, 2023.
https://araachie.github.io/river
GNU General Public License v3.0

About - initial_images.cuda(), # of shape [b n c h w] #2

Open ButoneDream opened 9 months ago

ButoneDream commented 9 months ago

Hello, could you tell me more about `initial_images.cuda()`?

Araachie commented 9 months ago

Hi,

By default the model is trained for the video prediction task: predicting k future frames given n initial ones. `initial_images` is a batch of b initial sequences, each of length n, from which you want to make predictions. For instance, you may be given some initial frames in which 2 objects from the CLEVRER dataset are approaching each other; the model is then supposed to predict their future motion, e.g. how they collide and move in the following frames. You may use the provided `VideoDataset` class to sample the initial images from the test set of the corresponding dataset. Note that the model only works on in-distribution data, i.e. `initial_images` should contain the same kind of data the model was trained on.
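Concretely, here is a minimal sketch of the expected input shape; the random tensor is for shape illustration only, and `model` is assumed to be a loaded `Model` with `generate_frames` called the same way as in the script later in this thread. In practice you would sample real frames from `VideoDataset` rather than noise, since the model only handles in-distribution data.

```python
import torch

# b = 2 sequences, n = 5 initial frames each, c = 3 channels, 64x64 pixels
initial_images = torch.randn(2, 5, 3, 64, 64)  # dummy data, shape [b, n, c, h, w]

# Predict 10 future frames conditioned on the 5 initial ones
generated = model.generate_frames(
    initial_images.cuda(),  # of shape [b n c h w]
    num_frames=10,
)
```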

ButoneDream commented 9 months ago

Hi!

```python
from lutils.configuration import Configuration
from lutils.logging import to_video
from model import Model
from dataset import VideoDataset
import torch
import matplotlib.pyplot as plt
import os

config = Configuration('configs/kth.yaml')
model = Model(config["model"])
model.load_from_ckpt('ckpt/kth_64.pth')
model.cuda()
model.eval()

test_dataset = VideoDataset(
    data_path='out_dir/test/shard_0001.hdf5',  # Replace with the actual path to your dataset
    input_size=256,  # Replace with the desired input size
    crop_size=64,  # Replace with the desired crop size
    frames_per_sample=5,
    skip_frames=0,
    random_time=True,
    random_horizontal_flip=True,
    aug=True,
    albumentations=False,
    total_videos=-1  # Set to the number of videos you want to include (-1 for all)
)

# Sample random initial sequences from the test set and stack them
# into a batch of shape [b, n, c, h, w]
num_sequences = 1
initial_images = []
for _ in range(num_sequences):
    initial_sequence = test_dataset[torch.randint(len(test_dataset), size=(1,)).item()]
    initial_images.append(initial_sequence.unsqueeze(0))
initial_images = torch.cat(initial_images, dim=0)

print("Shape of initial_images:", initial_images.shape)
generated_frames = model.generate_frames(
    initial_images.cuda(),  # of shape [b n c h w]
    num_frames=1,
    verbose=True)

generated_frames = to_video(generated_frames)

output_dir = 'generated_frames'
os.makedirs(output_dir, exist_ok=True)

for i in range(generated_frames.shape[0]):  # Assuming batch size is the first dimension
    for j in range(generated_frames.shape[1]):  # Assuming num_frames is the second dimension
        print(f"Processing frame {j + 1} of sequence {i + 1}")
        frame = generated_frames[i, j].transpose((1, 2, 0))  # Adjust dimensions for visualization
        print(f"Frame shape: {frame.shape}")

        # Visualize the frame
        plt.imshow(frame.astype('uint8'))
        plt.title(f'Generated Frame {j + 1}')
        plt.axis('off')

        # Save the frame
        save_path = os.path.join(output_dir, f'generated_frame_{i}_frame_{j}.png')
        plt.savefig(save_path)
        plt.close()

print(f"Generated frames saved in {output_dir}")
```

And the output in the console:

```
Working with z of shape (1, 4, 8, 8) = 256 dimensions.
Restored from ckpt/vqvae.ckpt
Checking shard_lengths in ['out_dir/test/shard_0001.hdf5']
h5: Opening out_dir/test/shard_0001.hdf5...
h5: paths 1 ; shard_lengths [36] ; total 36
Dataset length: 36
Shape of initial_images: torch.Size([1, 5, 3, 64, 64])
Processing frame 1 of sequence 1
Frame shape: (64, 64, 3)
Processing frame 2 of sequence 1
Frame shape: (64, 64, 3)
Processing frame 3 of sequence 1
Frame shape: (64, 64, 3)
Processing frame 4 of sequence 1
Frame shape: (64, 64, 3)
Processing frame 5 of sequence 1
Frame shape: (64, 64, 3)
Processing frame 6 of sequence 1
Frame shape: (64, 64, 3)
Generated frames saved in generated_frames
```

Lastly, the saved image looks like this: [image]

Can you help me figure out what the problem is? Thank you very much.

Araachie commented 9 months ago

For the KTH dataset I would suggest using the arguments specified in the example config, i.e. setting `random_horizontal_flip=False`, `aug=False`, and `albumentations=True`. KTH is a small dataset, so the autoencoder can be quite sensitive even to different interpolation techniques.
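For reference, a sketch of the dataset construction with those flags applied (all other arguments carried over unchanged from the script above):

```python
test_dataset = VideoDataset(
    data_path='out_dir/test/shard_0001.hdf5',
    input_size=256,
    crop_size=64,
    frames_per_sample=5,
    skip_frames=0,
    random_time=True,
    random_horizontal_flip=False,  # per the example KTH config
    aug=False,                     # per the example KTH config
    albumentations=True,           # per the example KTH config
    total_videos=-1
)
```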