PKU-YuanGroup / Video-LLaVA

Video-LLaVA: Learning United Visual Representation by Alignment Before Projection
https://arxiv.org/pdf/2311.10122.pdf
Apache License 2.0

Video-LLaVa now available in the Transformers library! #156

Open zucchini-nlp opened 2 months ago

zucchini-nlp commented 2 months ago

Hey!

Video-LLaVa is now available in the Transformers library! Feel free to check it out here. Thanks to @LinB203 for helping to ship the model 🤗

To get the model, update transformers by running: !pip install --upgrade git+https://github.com/huggingface/transformers.git. Inference with videos can be done as follows:

import av
import numpy as np
from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration

def read_video_pyav(container, indices):
    """Decode the video with PyAV and return only the frames at `indices` as an array of shape (num_frames, height, width, 3)."""
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])

model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")
processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")

prompt = "USER: <video>Why is this video funny? ASSISTANT:"
video_path = "YOUR-LOCAL-VIDEO-PATH
container = av.open(video_path)

# sample uniformly 8 frames from the video
total_frames = container.streams.video[0].frames
indices = np.arange(0, total_frames, total_frames / 8).astype(int)
clip = read_video_pyav(container, indices)

inputs = processor(text=prompt, videos=clip, return_tensors="pt")

# Generate
generate_ids = model.generate(**inputs, max_length=80)
print(processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])
>>> 'USER:  Why is this video funny? ASSISTANT: The video is funny because the baby is sitting on the bed and reading a book, which is an unusual and amusing sight.'

Check out:

LinB203 commented 2 months ago

It's a great feat. Thank you for your generous help!

rhelck commented 2 months ago

@zucchini-nlp I'm seeing the following problem

File "/home/rhelck/videotest.py", line 3, in from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration ImportError: cannot import name 'VideoLlavaProcessor' from 'transformers' (/home/rhelck/videovenv/lib/python3.10/site-packages/transformers/init.py)

The older example works fine for me, though. I reinstalled transformers in a new venv for this, by the way.

zucchini-nlp commented 2 months ago

@rhelck hey! Did you install transformers from main as follows? Video-LLaVa will be included in the next release, which I believe will be in a few days. For now you can get it from main 🤗

!pip install --upgrade git+https://github.com/huggingface/transformers.git

darshana1406 commented 2 months ago

@zucchini-nlp I want to distribute the model on multiple gpus.

raise ValueError(
ValueError: VideoLlavaForConditionalGeneration does not support device_map='auto'. To implement support, the model class needs to implement the _no_split_modules attribute.

zucchini-nlp commented 2 months ago

@darshana1406 could you open this as an issue in transformers and tag me there? I will add device_map support roughly by the end of this week.

Also, you are welcome to open a PR if you're willing to; we are always happy about community contributions 🤗
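
For reference, once _no_split_modules is added, loading the model across several GPUs should look roughly like this. This is a minimal sketch using the standard from_pretrained arguments, not something that works on the current release:

import torch
from transformers import VideoLlavaForConditionalGeneration

# sketch: shard the checkpoint across all visible GPUs (requires `accelerate` to be installed)
model = VideoLlavaForConditionalGeneration.from_pretrained(
    "LanguageBind/Video-LLaVA-7B-hf",
    torch_dtype=torch.float16,
    device_map="auto",
)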

rhelck commented 2 months ago

@zucchini-nlp That worked perfectly, thanks!

IsabelJimenez99 commented 2 months ago

Can it also be used with images as before or only for videos?

zucchini-nlp commented 2 months ago

@IsabelJimenez99, yes, the model can be used with images, videos, or a mix of images and videos. Check out the Colab notebook for inference examples with different input modalities.
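
For example, single-image inference looks roughly like this; a small sketch adapted from the video snippet above, with the <image> token in place of <video> and a placeholder image path:

from PIL import Image
from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration

model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")
processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")

prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
image = Image.open("YOUR-LOCAL-IMAGE-PATH")  # placeholder path

inputs = processor(text=prompt, images=image, return_tensors="pt")
generate_ids = model.generate(**inputs, max_new_tokens=80)
print(processor.batch_decode(generate_ids, skip_special_tokens=True)[0])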

IsabelJimenez99 commented 2 months ago

Ah, ok. Sorry, I hadn't seen the Colab. Thank you very much and excellent work. Congratulations!

BalloutAI commented 1 month ago

Can we use this library for fine-tuning as well or only for inference? If we can, is there documentation on how to use it properly? thanks!

zucchini-nlp commented 1 month ago

@BalloutAI Yes, we can. I am preparing a tutorial notebook for fine-tuning and will add it here, when it's done

BalloutAI commented 1 month ago

Thank you so much! Any expected timeline for that?

zucchini-nlp commented 1 month ago

@BalloutAI I made a short notebook for finetuning on a small dataset, you can find it here

IsabelJimenez99 commented 1 month ago

I am testing with the model 'LanguageBind/Video-LLaVA-7B-hf' and every time I run it on an image I get a different answer. I would also like to know how much confidence the model has in each response; is there a way to get that?

zucchini-nlp commented 1 month ago

@IsabelJimenez99 You mean the model gives different generation every time, even if you keep the same image and prompt? That shouldn't be the case, can you share a minimal reproducible code?

Regarding the model's confidence in each response, have a look at this thread which shows how to get probability of each generated token :)

IsabelJimenez99 commented 1 month ago

Yes, it's the same image and the same prompt but different answers. The code I used is the same as the one shown in your Colab.

This is the code:

import torch
from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration, BitsAndBytesConfig
import requests
from PIL import Image

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)

url = "../frames_testeo/00006.jpg"
image = Image.open(url)

model_id = "LanguageBind/Video-LLaVA-7B-hf"
processor = VideoLlavaProcessor.from_pretrained(model_id)
model = VideoLlavaForConditionalGeneration.from_pretrained(model_id, quantization_config=quantization_config)

# This time we will use a special "<image>" token instead of "<video>"
prompt = "USER: <image>\nWhich types of physical contact between people do you see in this image? Select all that you see from the following list: hand-hand, hand-shoulder, hand-elbow, hand-torso, elbow-shoulder, shoulder-shoulder, or none if there is no contact. Note: physical contact means that the mentioned body parts of different people are directly touching each other, not objects. ASSISTANT:"
inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)

# Generate
generate_kwargs = {"max_new_tokens":100, "do_sample":True, "top_p":0.9, "top_k":2}
generate_ids = model.generate(**inputs, **generate_kwargs)
generated_text = processor.batch_decode(generate_ids, skip_special_tokens=True)

print(generated_text[0])

On the other hand, I looked up what is happening to me, and the following is proposed:

outputs = model.generate(inputs, max_new_tokens=5, return_dict_in_generate=True, output_scores=True)
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)

However, when I adapt that to your code I get the following error:

generate_kwargs = {"max_new_tokens":100, "do_sample":True, "top_p":0.9, "top_k":2}
outputs = model.generate(inputs, generate_kwargs, output_scores=True)
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)

AttributeError: 'Tensor' object has no attribute 'sequences'

zucchini-nlp commented 1 month ago

@IsabelJimenez99 Ah, I see now. The different outputs each time are expected in this case because you have set do_sample=True, which samples the next token randomly from the logits distribution instead of taking the most likely token. To get a deterministic output, please use generate_kwargs = {"max_new_tokens":100} only.

And for the second issue, you need to set "return_dict_in_generate=True, output_scores=True" in the generate kwargs to get scores in the output. Otherwise we only return the generated text. For more details of which arguments you can pass in kwargs and what they mean, see the docs 🤗
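
Putting both points together, here is a minimal sketch (reusing the model, processor, and inputs from the image example above) of deterministic generation that also returns per-token scores:

# greedy decoding (no sampling) that also returns scores for each generated step
outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    return_dict_in_generate=True,
    output_scores=True,
)
transition_scores = model.compute_transition_scores(
    outputs.sequences, outputs.scores, normalize_logits=True
)
print(processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0])
print(transition_scores[0])  # log-probability of each generated token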

IsabelJimenez99 commented 1 month ago

Oh! I understand now, thank you very much! And sorry for the inconvenience

orrzohar commented 1 month ago

@zucchini-nlp Does this support batch inferencing for faster evaluations?

zucchini-nlp commented 1 month ago

@orrzohar yes, the model supports batching. For that you just have to pass the prompts as a list of strings, along with the list of visuals. You can also batch different visual modalities: for example, one prompt has only an image and another has only a video.

prompts = ["<video>USER: What do you see in the video? ASSISTANT:", "<image>USER: What do you see in the image? ASSISTANT:", "<video>USER: more video instructions..."]
inputs = processor(text=prompts, images=image, videos=[clip, clip_2], padding=True, return_tensors="pt")

n2nco commented 1 month ago

clip = read_video_pyav(container, indices)
prompts = ["<video>USER: What do you see in the video? ASSISTANT:", "<video>USER: Describe the man in this video's clothing ASSISTANT:"]
inputs = processor(text=prompts, videos=[clip, clip], return_tensors="pt", padding=True, truncation=True)

How might one most efficiently batch multiple prompts with 1 single clip/video?

e.g. to achieve batched prompts applied to 1 single video

Passing in videos=[clip, clip] seems to ~double the inference time


btw in case it helps anyone reading:

I had to add the padding & truncation args:

inputs = processor(text=prompts, videos=[clip, clip2], return_tensors="pt", padding=True, truncation=True)

zucchini-nlp commented 1 month ago

@n2nco in that case you have to pass the clip multiple times, as you have two separate prompts, each with a special "video" token. Transformers cannot align one clip with several prompts, as we don't know for sure whether that was the intention or a mistake in the code, so the safe way is to pass in as many clips as there are special "video" tokens :)
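
Concretely, something like this small sketch (reusing clip and processor from the snippets above):

# one clip, several prompts: repeat the clip once per <video> token
prompts = [
    "USER: <video>What do you see in the video? ASSISTANT:",
    "USER: <video>Describe the man in this video's clothing. ASSISTANT:",
]
inputs = processor(text=prompts, videos=[clip] * len(prompts), padding=True, return_tensors="pt")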

WeizhenWang-1210 commented 1 month ago

Just a side note: could you move the fine-tuned notebook to the main page Markdown? It'll be much easier to spot. Much appreciated!

zucchini-nlp commented 1 month ago

@WeizhenWang-1210 hey! We don't usually add these notebooks in Transformers docs, but you can find this one and many more in our tutorials repo 🤗

BalloutAI commented 1 month ago

Hey, thanks for the awesome work. I am trying to use it almost exactly as you are using it, but for some reason I am getting 100% accuracy even before training (on the sanity check I increased it to 20), which is impossible because I checked your demo and the performance was really bad before training. I was wondering if I am doing something wrong in my data handling:

def read_video_pyav(video_path, start, end):
    """Reads a video for given start-end timestamps interval and uniformly samples 8 frames of it"""
    container = av.open(video_path)
    video = container.streams.get(0)[0]

    av_timestamps = [
        int(packet.pts * video.time_base) for packet in container.demux(video) if packet.pts is not None
    ]

    av_timestamps.sort()
    start_id = bisect.bisect_left(av_timestamps, start)
    end_id = bisect.bisect_left(av_timestamps, end)

    # in case it is a very short video, lets take a longer duration and sample
    if end_id - start_id < 10:
        end_id += 10
        start_id -= 10

    end_id = min(len(av_timestamps) - 1, end_id)
    start_id = max(1, start_id)
    indices = np.linspace(start_id, end_id, 8).astype(int)

    frames = []
    container.seek(0)
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_id:
            break
        if i >= start_id and i in indices:
            frames.append(frame)
    assert len(frames) == 8, f"Got {len(frames)} frames but should be 8. Check the indices: {indices}; start_id: {start_id}, end_id: {end_id}. Len of video is {len(av_timestamps)} frames."
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])

def collate_read_video(example, path):
    clip = read_video_pyav(example["video"], example.get("start", 1), example.get("end", 1e+10))
    example["clip"] = clip
    return example

def load_videos_from_directory(directory):
    data = {"video": [], "label": []}
    for label in ["True", "False"]:
        folder = os.path.join(directory, label)
        for filename in os.listdir(folder):
            if filename.endswith(".mp4"):
                data["video"].append(os.path.join(folder, filename))
                data["label"].append(1 if label == "True" else 0)
    return data

data = load_videos_from_directory("/mypath")
hf_dataset = HFDataset.from_dict(data)
dataset = hf_dataset.train_test_split(test_size=0.2)

dataset = dataset.map(collate_read_video, batched=False, fn_kwargs={"path": ""}, writer_batch_size=100)

processor = AutoProcessor.from_pretrained(MODEL_ID)
processor.tokenizer.padding_side = "right"  # during training, one always uses padding on the right

class VideoLlavaDataset(Dataset):
    """
    PyTorch Dataset for VideoLlavaDataset. This class takes a HuggingFace Dataset as input.
    """

    def __init__(
        self,
        dataset: HFDataset,
    ):
        super().__init__()
        self.dataset = dataset

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int):
        sample = self.dataset[idx]
        clip = np.array(sample["clip"])

        label = sample["label"]
        label_text = "True" if label == 1 else "False"
        mult_choice = "True or False"

        prompt = f"USER: <video>\nAnswer the following question based on the video by {mult_choice}. " \
                 f"ASSISTANT: Answer: {label_text}"

        return prompt, clip

def train_collate_fn(examples):
    texts, videos = list(zip(*examples))

    batch = processor(text=texts, videos=videos, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt")

    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    batch["labels"] = labels

    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    pixel_values_videos = batch["pixel_values_videos"]
    labels = batch["labels"]

    return input_ids, attention_mask, pixel_values_videos, labels

def eval_collate_fn(examples):
    texts, videos = list(zip(*examples))
    texts = list(texts)

    batch = processor(text=texts, videos=videos, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt")

    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    pixel_values_videos = batch["pixel_values_videos"]
    answer_choice = [text.split("Answer: ")[-1] for text in texts]  # Extract the answer text
    return input_ids, attention_mask, pixel_values_videos, answer_choice

train_dataset = VideoLlavaDataset(dataset["train"])
eval_dataset = VideoLlavaDataset(dataset["test"])

class VideoLlavaModelPLModule(L.LightningModule):
    def __init__(self, config, processor, model):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model

        self.batch_size = config.get("batch_size")
        self.predictions = []
        self.answers = []

    def training_step(self, batch, batch_idx):
        input_ids, attention_mask, pixel_values_videos, labels = batch

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values_videos=pixel_values_videos,
            labels=labels
        )
        loss = outputs.loss

        self.log("train_loss", loss)

        return loss

    def validation_step(self, batch, batch_idx, dataset_idx=0):
        input_ids, attention_mask, pixel_values_videos, answers = batch

        # Autoregressively generate token IDs
        generated_ids = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values_videos=pixel_values_videos,
            max_new_tokens=MAX_LENGTH,
            do_sample=False,
        )

        # Decode the generated token IDs into text, chopping off the prompt
        decoded_predictions = self.processor.batch_decode(generated_ids, skip_special_tokens=True)

        # Extract the word after "Answer: "
        predictions = []
        for pred in decoded_predictions:
            if "Answer:" in pred:
                answer_part = pred.split("Answer:")[-1].strip()
                predictions.append(answer_part.split()[0])  # Get the first word after "Answer:"
            else:
                predictions.append("")  # Handle cases where "Answer:" is not found

        correct = 0
        for pred, answer in zip(predictions, answers):
            normalized_pred = pred.strip().lower()
            print(normalized_pred)
            normalized_answer = answer.strip().lower()
            print(normalized_answer)
            correct += (normalized_pred == normalized_answer)

        accuracy = correct / len(answers)

        # Store the predictions and answers for epoch-end processing
        self.predictions.extend(predictions)
        self.answers.extend(answers)

        return correct

    def on_validation_epoch_end(self):
        correct = sum([pred.strip().lower() == ans.strip().lower() for pred, ans in zip(self.predictions, self.answers)])
        accuracy = correct / len(self.answers)
        print(len(self.answers))

        print(f"on_Validation Accuracy: {accuracy * 100:.2f}%")

n2nco commented 1 month ago

Do you think it'd be straightforward to swap the Vicuna-7B backbone for a Llama-3-8B base? e.g. https://huggingface.co/lmms-lab/llama3-llava-next-8b

zucchini-nlp commented 1 month ago

@BalloutAI, I am not sure where the "question" you're referring to is in the prompt, and it's weird that the model is getting 100%. Did you try verifying that the validation dataloader is correct (shapes and content) and turning on verbose mode to print the predictions/answers?

@n2nco yes, swapping the backbone LLM should be easy by tweaking the model's config, but the new model would require training. AFAIK the LLaVA-NeXT model you're pointing to can do generation on videos even though it wasn't trained for that. We're working on adding those in transformers 😄

BalloutAI commented 1 month ago

Yeah, I have tried printing, and it is getting them correctly: ['USER: \nAnswer the following question based on the video by True or False. ASSISTANT: Answer: True']. And it is answering them correctly no matter what the question is, for some reason. My guess was that I am feeding the answers to the model directly somehow, but I can't find the problem, because I am getting my answer from the decoded_predictions.

zucchini-nlp commented 1 month ago

@BalloutAI Ah, sorry, you're right! I didn't see that you had a different collate_fn. In eval_collate_fn, when you feed the text to the tokenizer, you have to get rid of the answer first.


texts = [text.split("Answer: ")[0] for text in texts]  # strip the answer so the model has to predict it
batch = processor(text=texts, videos=videos, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt")
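
To illustrate why the accuracy looked perfect before this fix (the label was still part of the prompt the model saw), here is what the split does on an example string; illustrative only:

text = "USER: <video>\nAnswer the following question based on the video by True or False. ASSISTANT: Answer: True"
print(text.split("Answer: ")[0])   # prompt ending in "ASSISTANT: " -- the model must now predict the label
print(text.split("Answer: ")[-1])  # "True" -- kept separately as the ground-truth answer
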
BalloutAI commented 1 month ago

Awesome, thanks! I expected that!

caichuang0415 commented 1 month ago

Thanks for your contribution. But I came across a bug: ValueError: Video pixel values should have exactly 8 frames but foung 24. I tried to sample 24 frames from a video and it raised this error. Does it only support sampling 8 frames from a video? How can we put more frames into it, or pass the whole video?

zucchini-nlp commented 1 month ago

@caichuang0415 hey! Yes, since VideoLlava was trained with 8 frames, we currently support only 8-frame videos. You can open a PR if you want to give it a chance; otherwise I'll take a look at it next week :)

zucchini-nlp commented 1 month ago

@caichuang0415 Video-LLaVa can now work with any number of frames at input, but note that inference with more than 8 frames degrades quality, as the model wasn't trained in that setting. I recommend tuning with 24 frames first if you want good performance.

To get the updated version, please update transformers with: !pip install --upgrade git+https://github.com/huggingface/transformers.git
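
For example, sampling a different number of frames with the helper from the first snippet might look like this; just a sketch, and quality beyond 8 frames is not guaranteed without fine-tuning:

# sample `num_frames` uniformly instead of the fixed 8 (the video must have at least that many frames)
num_frames = 24
total_frames = container.streams.video[0].frames
indices = np.linspace(0, total_frames - 1, num_frames).astype(int)
clip = read_video_pyav(container, indices)
inputs = processor(text=prompt, videos=clip, return_tensors="pt")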

caichuang0415 commented 1 month ago

Thanks for the update! I will take your advice and run more experiments.

zucchini-nlp commented 3 weeks ago

@s-s-la which notebook are you using? The one I linked above leads to VideoLlava and works in 4.42.

The error message mentions another model which I'll merge into transformers on Monday and post about it in LlavaNext repo ;)

sherlock666 commented 2 weeks ago

@zucchini-nlp how can I fine-tune with more sampled frames? The comment in the Video-LLaVA fine-tuning notebook says:

We sample 8 frames for tuning following the original paper

But we can increase the number of frames for longer videos and check out if it helps performance

Change the below "8" to any number of frames you want, and note that more frames -> more computational resources needed

indices = np.linspace(start_id, end_id, 8).astype(int)

However, after I set it to 30 and fine-tune, it shows:

Traceback (most recent call last):
  File "videollava_finetune_original_100.py", line 505, in <module>
    trainer.fit(model_module)
  File "/usr/local/lib/python3.8/dist-packages/lightning/pytorch/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/usr/local/lib/python3.8/dist-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/usr/local/lib/python3.8/dist-packages/lightning/pytorch/trainer/trainer.py", line 987, in _run
    results = self._run_stage()
  File "/usr/local/lib/python3.8/dist-packages/lightning/pytorch/trainer/trainer.py", line 1031, in _run_stage
    self._run_sanity_check()
  File "/usr/local/lib/python3.8/dist-packages/lightning/pytorch/trainer/trainer.py", line 1060, in _run_sanity_check
    val_loop.run()
  File "/usr/local/lib/python3.8/dist-packages/lightning/pytorch/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/lightning/pytorch/loops/evaluation_loop.py", line 135, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/usr/local/lib/python3.8/dist-packages/lightning/pytorch/loops/evaluation_loop.py", line 396, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
  File "/usr/local/lib/python3.8/dist-packages/lightning/pytorch/trainer/call.py", line 309, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/lightning/pytorch/strategies/strategy.py", line 412, in validation_step
    return self.lightning_module.validation_step(*args, **kwargs)
  File "videollava_finetune_original_100.py", line 435, in validation_step
    generated_ids = self.model.generate(
  File "/usr/local/lib/python3.8/dist-packages/peft/peft_model.py", line 647, in generate
    return self.get_base_model().generate(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/generation/utils.py", line 1758, in generate
    result = self._sample(
  File "/usr/local/lib/python3.8/dist-packages/transformers/generation/utils.py", line 2397, in _sample
    outputs = self(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/video_llava/modeling_video_llava.py", line 513, in forward
    image_outputs, video_outputs = self._get_vision_features(
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/video_llava/modeling_video_llava.py", line 377, in _get_vision_features
    raise ValueError(f"Video pixel values should have exactly 8 frames but foung {num_frames}")
ValueError: Video pixel values should have exactly 8 frames but foung 30

Does this mean that if I really want to change 8 to 30 I need to fully train the model again? If so, I suggest deleting that comment, since it's confusing.

Also, another question: if I set more than about 50 frames, it causes this error:

OverflowError: There was an overflow with type <class 'list'>. Try to reduce writer_batch_size to have batches smaller than 2GB. (offset overflow while concatenating arrays)

How can I solve this if I really want to use that many frames?

thanks!!!

zucchini-nlp commented 2 weeks ago

@sherlock666 can you update your transformers version and install from main with !pip install --upgrade git+https://github.com/huggingface/transformers.git ?

sherlock666 commented 2 weeks ago

Thanks for the reply.

So do you mean that the latest transformers actually won't cause those two problems?

zucchini-nlp commented 2 weeks ago

It will solve the first problem. The second can be solved by decreasing writer_batch_size as the error msg says. The default is 1000 afaik.

The issue is that when you sample more frames, and your videos are high-resolution, you end up with memory-consuming batches. I had a similar problem with another model (at 8 frames). You can also consider doing the collate and "read_video" in one dataset.map() so that we don't have to "write" the unprocessed video; in that case each video will have a fixed 336x336 size, which lowers the memory consumption per batch.

Hope it's clear :)
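
A rough sketch of that suggestion, reusing the names from the fine-tuning snippet above (untested, so treat it as a starting point rather than a drop-in fix):

def read_and_process(example):
    # read the clip and run the processor right away, so only fixed-size 336x336
    # pixel values are written to the Arrow cache instead of raw frames
    clip = read_video_pyav(example["video"], example.get("start", 1), example.get("end", 1e10))
    processed = processor(text="USER: <video> ASSISTANT:", videos=clip, return_tensors="np")
    example["pixel_values_videos"] = processed["pixel_values_videos"][0]
    return example

# smaller writer batches also keep each Arrow write well under the 2GB limit
dataset = dataset.map(read_and_process, batched=False, writer_batch_size=20)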

sherlock666 commented 2 weeks ago

I just checked; I'm using Docker with the latest transformers release (4.41.2). Or should I change to a certain older version? (I am using the fine-tuning code.)

Thanks for help

Update 1: Kindly note that the error leads to here, where the "8" is a fixed value:

https://github.com/huggingface/transformers/blob/dc76e9fa7f0d19ff7cfc33bd3a22acd7df167fce/src/transformers/models/video_llava/modeling_video_llava.py#L377

zucchini-nlp commented 2 weeks ago

Sorry if I wasn't clear, I meant updating to the version from main, since the latest release is planned for today and is not out yet. The CLI command above should update your install to the main branch :)