lucidrains / phenaki-pytorch

Implementation of Phenaki Video, which uses MaskGIT to produce text-guided videos of up to 2 minutes in length, in PyTorch

Unconditional Training returns errors #26

Closed 1geek0 closed 1 year ago

1geek0 commented 1 year ago

I'm trying to train an unconditional model on image and gif data I have, with the goal of generating coherent video from gifs of manga panels:

from phenaki_pytorch import CViViT, MaskGit, Phenaki, PhenakiTrainer

cvivit = CViViT(
    dim = 512,
    codebook_size = 5000,
    image_size = 256,
    patch_size = 32,
    temporal_patch_size = 2,
    spatial_depth = 4,
    temporal_depth = 4,
    dim_head = 64,
    heads = 8
)

maskgit = MaskGit(
    num_tokens = 5000,
    max_seq_len = 1024,
    dim = 512,
    dim_context = 768,
    depth = 6,
    unconditional = True # Kept this true, otherwise it asks for text samples (I only have image data)
)

phenaki = Phenaki(
    cvivit = cvivit,
    maskgit = maskgit
).cuda()

trainer = PhenakiTrainer(
    phenaki = phenaki,
    batch_size = 4,
    grad_accum_every = 4,
    train_on_images = True,
    folder = '../dataset/compressed_manga/'
)

trainer.train()

When training, the following error is raised: sample_images() got an unexpected keyword argument 'num_frames'


I think the num_frames argument is being passed to the sample_images() method, which doesn't accept it. Can someone confirm this is a bug? I'll submit a PR with a fix.
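For context, the shape of the fix I have in mind is just to not forward the kwarg when sampling images. A rough sketch (the helper and its arguments are mine, not the actual trainer code):

def sample_from(model, train_on_images, num_frames):
    if train_on_images:
        # Phenaki.sample_images() does not accept num_frames
        return model.sample_images()
    # video sampling still takes the frame count
    return model.sample(num_frames = num_frames)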

lucidrains commented 1 year ago

@1geek0 hi Nilay! thanks for testing it out! this should be fixed in 0.0.66

right now the trainer does not support training on both gifs and images by the way; you'll have to train on an image folder separately before upgrading to video training
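schematically the two stages would look something like this (reusing your trainer settings from above; the gif folder path and the video-stage arguments are just placeholders, not tested):

# stage 1: train maskgit on still images first
image_trainer = PhenakiTrainer(
    phenaki = phenaki,
    batch_size = 4,
    grad_accum_every = 4,
    train_on_images = True,
    folder = '../dataset/compressed_manga/'
)
image_trainer.train()

# stage 2: then switch to video training on a separate gif folder
video_trainer = PhenakiTrainer(
    phenaki = phenaki,
    batch_size = 4,
    grad_accum_every = 4,
    train_on_images = False,
    folder = '../dataset/manga_gifs/'  # placeholder path
)
video_trainer.train()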

1geek0 commented 1 year ago

Thanks for replying @lucidrains. I am trying to train on 256x256 jpegs. No gifs

lucidrains commented 1 year ago

@1geek0 ok nice! make sure your cvivit is trained first before you attempt full phenaki!
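per the readme, cvivit pretraining looks roughly like this (argument names follow the readme's CViViTTrainer example; exact defaults may differ between versions):

from phenaki_pytorch import CViViTTrainer

cvivit_trainer = CViViTTrainer(
    cvivit,
    folder = '../dataset/compressed_manga/',
    batch_size = 4,
    grad_accum_every = 4,
    train_on_images = True,  # jpeg-only dataset
    num_train_steps = 10000
)

cvivit_trainer.train()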

1geek0 commented 1 year ago

Ah! Missed that. Now training C-ViViT. Will try with phenaki afterwards
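Once the C-ViViT checkpoint is saved, the plan is to load it before wiring up Phenaki; a sketch following the readme's cvivit.load usage (the checkpoint path is a placeholder):

# load the pretrained cvivit weights before building Phenaki
cvivit.load('./results/cvivit.pt')  # placeholder checkpoint path

phenaki = Phenaki(
    cvivit = cvivit,
    maskgit = maskgit
).cuda()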

1geek0 commented 1 year ago

We can close this. I've fixed the error in my local repo. Happy to raise a PR.