Open dongzhuoyao opened 2 months ago
I ran into the same problem and solved it. In omnitokenizer.py, change these lines so that they read from the misspelled argument names:
self.encoder = OmniTokenizer_Encoder(...causal_in_temporal_transformer=args.casual_in_temporal_transformer, causal_in_peg=args.casual_in_peg, ...)
self.decoder = OmniTokenizer_Decoder(...causal_in_temporal_transformer=args.casual_in_temporal_transformer, causal_in_peg=args.casual_in_peg, ...)
As you can see, the authors misspelled "causal" as "casual" in the argument names, which is why the attribute 'causal_...' does not exist on args.
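If you would rather not edit omnitokenizer.py, the same mismatch can be papered over from the outside. The following is only a minimal sketch: `add_causal_aliases` is a hypothetical helper (not part of OmniTokenizer), and it assumes you can get hold of the argparse `Namespace` before the model is constructed, which may not be the case when loading via `load_from_checkpoint`.

```python
from argparse import Namespace


def add_causal_aliases(args: Namespace) -> Namespace:
    """Copy the misspelled 'casual_*' attributes to their 'causal_*' spelling.

    Hypothetical helper: existing 'causal_*' attributes are left untouched.
    """
    for suffix in ("in_temporal_transformer", "in_peg"):
        misspelled = f"casual_{suffix}"
        correct = f"causal_{suffix}"
        if hasattr(args, misspelled) and not hasattr(args, correct):
            setattr(args, correct, getattr(args, misspelled))
    return args


# Example with made-up values, mimicking the Namespace from the traceback.
args = add_causal_aliases(
    Namespace(casual_in_temporal_transformer=True, casual_in_peg=False)
)
```

After this call both spellings resolve, so code that reads either `args.casual_in_peg` or `args.causal_in_peg` works.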
Here's the code I use to encode and decode videos, once the errors mentioned above are corrected:
from OmniTokenizer import OmniTokenizer_VQGAN
import torch
from torchvision.io import write_video
import numpy as np
from decord import VideoReader, cpu
from einops import rearrange
device = 'cuda:0'
dtype=torch.bfloat16
vqgan = OmniTokenizer_VQGAN.load_from_checkpoint('imagenet_k600.ckpt', strict=False)
vqgan.requires_grad_(False)
vqgan.eval()
vqgan = vqgan.to(device, dtype=dtype)
video_reader = VideoReader('input.mp4', ctx=cpu(0))
fps = video_reader.get_avg_fps()
video = video_reader.get_batch(list(range(len(video_reader)))).asnumpy()
video = torch.from_numpy(video).to(dtype)
video = rearrange(video[:-3], 't h w c -> 1 c t h w') # drop trailing frames to avoid divisible-by-4 errors; how many to drop depends on the input's frame count
video = video / 255 # (0.0-1.0)
video = video.to(device=device, dtype=dtype)
video = video - 0.5 # (-0.5-0.5)
video = video.clamp(-0.5, 0.5)
with torch.no_grad():
    tokens = vqgan.encode(video, is_image=False)
    print(tokens.shape)
    recons = vqgan.decode(tokens, is_image=False)
video_dec = recons.clamp(-0.5, 0.5)
video_dec = rearrange(video_dec.squeeze(0), 'c t h w -> t h w c') # format for output
video_dec = (video_dec + 0.5).clamp(0, 1)
video_dec = video_dec.cpu().float().numpy()
video_dec = (video_dec * 255).astype(np.uint8) # Convert to 0-255 and write
write_video("recon.mp4", video_dec, fps=fps, options={'crf': '0'})
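The hard-coded `video[:-3]` in the script above only suits one particular input length. A small helper can compute how many frames to keep for any clip. This is a sketch under assumptions: `usable_frame_count` is a hypothetical name, and whether the model wants a frame count of the form 4k or 4k + 1 is not confirmed here, so the `offset` parameter is left for you to set.

```python
def usable_frame_count(num_frames: int, multiple: int = 4, offset: int = 0) -> int:
    """Return the largest count <= num_frames of the form offset + k * multiple.

    `multiple` and `offset` are assumptions about the model's temporal
    layout; adjust them if the encoder still rejects the shape.
    """
    if multiple <= 0:
        raise ValueError("multiple must be positive")
    if num_frames < offset:
        raise ValueError("not enough frames for the requested offset")
    return offset + ((num_frames - offset) // multiple) * multiple


# A 19-frame clip keeps 16 frames (4k) or 17 frames (4k + 1).
keep_4k = usable_frame_count(19)
keep_4k_plus_1 = usable_frame_count(19, offset=1)
```

In the script this would replace the slice, for example `video[:usable_frame_count(len(video))]` before the `rearrange` call.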
File "/export/scratch/ra63nev/lab/discretediffusion/OmniTokenizer/omnitokenizer.py", line 108, in __init__
    spatial_depth=args.spatial_depth, temporal_depth=args.temporal_depth, causal_in_temporal_transformer=args.causal_in_temporal_transformer, causal_in_peg=args.causal_in_peg,
                                                                                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'Namespace' object has no attribute 'causal_in_temporal_transformer'. Did you mean: 'casual_in_temporal_transformer'?
I tried two checkpoints; neither works.
vqgan_ckpt = "./pretrained_ckpt/imagenet_k600.ckpt"