FoundationVision / OmniTokenizer

[NeurIPS 2024] OmniTokenizer: one model and one weight for image-video joint tokenization.
https://www.wangjunke.info/OmniTokenizer/
MIT License

error when initializing the OmniTokenizer #20

Open · dongzhuoyao opened this issue 2 months ago

dongzhuoyao commented 2 months ago

  File "/export/scratch/ra63nev/lab/discretediffusion/OmniTokenizer/omnitokenizer.py", line 108, in __init__
    spatial_depth=args.spatial_depth, temporal_depth=args.temporal_depth, causal_in_temporal_transformer=args.causal_in_temporal_transformer, causal_in_peg=args.causal_in_peg,
                                                                                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'Namespace' object has no attribute 'causal_in_temporal_transformer'. Did you mean: 'casual_in_temporal_transformer'?

I tried two checkpoints, and neither works:

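import torch
from PIL import Image
from OmniTokenizer import OmniTokenizer_VQGAN

vqgan_ckpt = "./pretrained_ckpt/imagenet_k600.ckpt"
# vqgan_ckpt = "./pretrained_ckpt/imagenet_ucf.ckpt"  # second checkpoint, same error

device = "cuda"  # not defined in the original post; assumed here

vqgan_omni = OmniTokenizer_VQGAN.load_from_checkpoint(vqgan_ckpt, strict=False)
omni_tokenizer = vqgan_omni.to(device).eval()

# load_and_preprocess_image and img_path are the poster's own helper and path (not shown)
image = load_and_preprocess_image(img_path)
image = image.to(device)

with torch.no_grad():
    indices = omni_tokenizer.encode(image)
    print(
        f"image {img_path} is encoded into tokens {indices}, with shape {indices.shape}"
    )
    # de-tokenization
    reconstructed_image = omni_tokenizer.decode(indices)

# assumes decode() returns pixel values in [0, 1]
reconstructed_image = torch.clamp(reconstructed_image, 0.0, 1.0)
reconstructed_image = (
    (reconstructed_image * 255.0)
    .permute(0, 2, 3, 1)  # (B, C, H, W) -> (B, H, W, C)
    .to("cpu", dtype=torch.uint8)
    .numpy()[0]  # first image in the batch
)
Image.fromarray(reconstructed_image).save("reconstructed_image_omni.png")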
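The snippet above clamps the decoded image to [0, 1], while the video example later in this thread treats the tokenizer's pixel range as [-0.5, 0.5]. If your preprocessing follows the video example, clamping to [0, 1] clips the lower half of the range and leaves the saved image at most half as bright. A minimal NumPy sketch of range-aware post-processing, assuming a decoded (C, H, W) array for a single image; `to_uint8_image` is a hypothetical helper, and the [-0.5, 0.5] default is an assumption taken from the video example, not from OmniTokenizer's documentation.

```python
import numpy as np


def to_uint8_image(x: np.ndarray, lo: float = -0.5, hi: float = 0.5) -> np.ndarray:
    """Map a decoded (C, H, W) float array in [lo, hi] to an (H, W, C) uint8 image.

    ASSUMPTION: the tokenizer's pixel range is [-0.5, 0.5], as in the video
    example in this thread; pass lo=0.0, hi=1.0 if your preprocessing uses [0, 1].
    """
    x = np.clip(x, lo, hi)                    # drop out-of-range values
    x = (x - lo) / (hi - lo)                  # rescale to [0, 1]
    x = np.round(x * 255.0).astype(np.uint8)  # round to the nearest 0-255 integer
    return np.transpose(x, (1, 2, 0))         # CHW -> HWC, the layout PIL expects


# Toy 3x2x2 "decoded" image with the same values in every channel.
decoded = np.array([[[-0.5, 0.0], [0.25, 0.5]]] * 3, dtype=np.float32)
img = to_uint8_image(decoded)
print(img.shape, img[:, :, 0].tolist())  # (2, 2, 3) [[0, 128], [191, 255]]
```

Rounding instead of the plain `uint8` cast used above avoids a systematic downward bias of up to one gray level; the result can be passed straight to `Image.fromarray`.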

LisIva commented 1 month ago

I ran into the same problem and solved it. In omnitokenizer.py, fix these lines:

self.encoder = OmniTokenizer_Encoder(...causal_in_temporal_transformer=args.casual_in_temporal_transformer, causal_in_peg=args.casual_in_peg, ...)

self.decoder = OmniTokenizer_Decoder(...causal_in_temporal_transformer=args.casual_in_temporal_transformer, causal_in_peg=args.casual_in_peg, ...)

As you can see, the authors made a typo in the word "causal" (the args are named "casual_..."), which is why the attribute 'causal_...' does not exist.
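Editing omnitokenizer.py is the simplest fix. The traceback's "Did you mean" hint shows that the `args` Namespace the model is built from only carries the misspelled `casual_*` names, so an alternative is to add correctly spelled aliases to that Namespace before the encoder and decoder are constructed. A minimal sketch, assuming you can get hold of `args` first (for example the hyperparameters restored by `load_from_checkpoint`; how they are stored in the checkpoint is an assumption here). `add_causal_aliases` is a hypothetical helper, not part of OmniTokenizer, and the attribute values in the example are made up.

```python
from argparse import Namespace


def add_causal_aliases(args: Namespace) -> Namespace:
    """Copy every misspelled 'casual_*' attribute to a 'causal_*' alias.

    Hypothetical helper: attributes already spelled 'causal_*' are left untouched.
    """
    # Iterate over a snapshot, because setattr() adds entries to vars(args).
    for name, value in list(vars(args).items()):
        if name.startswith("casual_"):
            fixed = "causal_" + name[len("casual_"):]
            if not hasattr(args, fixed):
                setattr(args, fixed, value)
    return args


# A Namespace shaped like the one in the traceback (values are made up).
args = Namespace(casual_in_temporal_transformer=True, casual_in_peg=True, spatial_depth=4)
add_causal_aliases(args)
print(args.causal_in_temporal_transformer, args.causal_in_peg)  # True True
```

Keeping the original `casual_*` names as well means any other code that still reads the misspelled attributes keeps working.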

NilanEkanayake commented 1 month ago

Here's the code I use to encode and decode videos, once the errors mentioned above are corrected:

from OmniTokenizer import OmniTokenizer_VQGAN
import torch
from torchvision.io import write_video
import numpy as np
from decord import VideoReader, cpu
from einops import rearrange

device = 'cuda:0'
dtype = torch.bfloat16
vqgan = OmniTokenizer_VQGAN.load_from_checkpoint('imagenet_k600.ckpt', strict=False)
vqgan.requires_grad_(False)
vqgan.eval()
vqgan = vqgan.to(device, dtype=dtype)

video_reader = VideoReader('input.mp4', ctx=cpu(0))
fps = video_reader.get_avg_fps()

video = video_reader.get_batch(list(range(len(video_reader)))).asnumpy()
video = torch.from_numpy(video).to(dtype)
# Drop the last few frames so the frame count fits the temporal downsampling
# (otherwise you get "/4" errors); adjust the 3 to your clip's frame count.
video = rearrange(video[:-3], 't h w c -> 1 c t h w')

video = video / 255  # scale to [0.0, 1.0]

video = video.to(device=device, dtype=dtype)

video = video - 0.5  # shift to [-0.5, 0.5]
video = video.clamp(-0.5, 0.5)

with torch.no_grad():
    tokens = vqgan.encode(video, is_image=False)
    print(tokens.shape)
    recons = vqgan.decode(tokens, is_image=False)

video_dec = recons.clamp(-0.5, 0.5)

video_dec = rearrange(video_dec.squeeze(0), 'c t h w -> t h w c') # format for output
video_dec = (video_dec + 0.5).clamp(0, 1)
video_dec = video_dec.cpu().float().numpy()

video_dec = (video_dec * 255).astype(np.uint8)  # Convert to 0-255 and write
write_video("recon.mp4", video_dec, fps=fps, options={'crf': '0'})
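The script hard-codes `video[:-3]`, which, as the comment says, only suits one particular clip length. A small sketch for computing the cut automatically, assuming "/4 errors" means the frame count must be a multiple of 4. Some causal video tokenizers instead encode the first frame on its own and need T = 1 + 4k, so the helper supports both rules; which one these checkpoints need is not stated in the thread, so check it against your input. `usable_frame_count` is a hypothetical helper, not part of OmniTokenizer.

```python
def usable_frame_count(num_frames: int, factor: int = 4, keep_first_frame: bool = False) -> int:
    """Return the largest frame count <= num_frames that fits the temporal downsampling.

    ASSUMPTION: by default T must be a multiple of `factor`; with
    keep_first_frame=True the rule is T = 1 + k * factor instead.
    """
    if num_frames < 1:
        raise ValueError("need at least one frame")
    if keep_first_frame:
        return 1 + (num_frames - 1) // factor * factor
    usable = num_frames // factor * factor
    if usable == 0:
        raise ValueError(f"need at least {factor} frames")
    return usable


# Hypothetical usage in the script above, replacing the hard-coded video[:-3]:
#   video = video[: usable_frame_count(len(video))]
print(usable_frame_count(19), usable_frame_count(19, keep_first_frame=True))  # 16 17
```

Trimming from the end keeps the clip's start aligned with the original, so the reconstruction can be compared frame by frame with the input.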