rese1f / MovieChat

[CVPR 2024] MovieChat: From Dense Token to Sparse Memory for Long Video Understanding
https://rese1f.github.io/MovieChat/
BSD 3-Clause "New" or "Revised" License

Breakpoint mode code questions? #18

Closed tian1327 closed 1 year ago

tian1327 commented 1 year ago

Hi, thank you for the great work! I would appreciate it if you could address some of my questions about breakpoint mode. From the code below:

  1. The num_frames is the number of segments before the current second, and the cur_frame is the frame index in the current segment. Do I understand it correctly?
  2. It seems that only the video fragments before the queried time are considered in the long-term memory. Is this true? If so, should it actually incorporate all fragments in the video?
  3. Why, when calling self.model.encode_short_memory_frame(video_fragment, cur_frame) on line 271, do we only consider the first cur_frame frames of each segment? Should this apply only to the last segment? In other words, for the segments before the last one, should we use line 273, which considers all frames in the segment?

image

tian1327 commented 1 year ago

Also, the code below has a bug that could cause an infinite loop, because popping does not change the value of cur_short_length. Would you please fix it?

image

Espere-1119-Song commented 1 year ago
  1. Your understanding of num_frames and cur_frame is correct.
  2. Yes, only the video fragments before the queried time are incorporated into the long-term memory.
  3. Sorry, the condition on line 270 is wrong. It should be if middle_video and (i+1)==num_frames:. That way, reading stops at cur_frame only when the last fragment is being read, instead of stopping at cur_frame in every fragment. I will fix this in the next version.
  4. Thanks for the reminder, I've noticed this issue. The complete code is still being compiled. You can modify the code on lines 332 to 333 as follows:
    while (len(self.long_memory_buffer)+len(self.temp_short_memory)+1) > frame_position_embeddings.shape[0]:
        if len(self.temp_short_memory) != 0:
            self.temp_short_memory.pop(0)
        else:
            self.long_memory_buffer.pop(0)
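
For anyone following along, the two fixes above can be sketched in isolation. This is a minimal sketch under stated assumptions, not the repository's code: frames_to_encode, evict_until_fits, fragment_len, and max_positions are hypothetical names, with max_positions standing in for frame_position_embeddings.shape[0] and plain lists standing in for the memory buffers.

```python
def frames_to_encode(i, num_frames, cur_frame, fragment_len, middle_video):
    """Return how many frames of fragment i (0-based) should be read.

    Hypothetical helper mirroring the corrected condition on line 270:
    only the last fragment before the queried time stops at cur_frame;
    every earlier fragment is read in full.
    """
    if middle_video and (i + 1) == num_frames:
        return cur_frame
    return fragment_len


def evict_until_fits(long_memory_buffer, temp_short_memory, max_positions):
    """Drop the oldest entries until both buffers plus one slot fit.

    The loop condition re-reads len() on every iteration, so each pop
    shrinks the measured total and the loop terminates. Comparing against
    a length cached before the loop is what can spin forever.
    Assumes max_positions >= 1, so the loop exits before both lists are empty.
    """
    while len(long_memory_buffer) + len(temp_short_memory) + 1 > max_positions:
        if len(temp_short_memory) != 0:
            temp_short_memory.pop(0)   # evict short-term entries first
        else:
            long_memory_buffer.pop(0)  # then fall back to long-term entries
    return long_memory_buffer, temp_short_memory


if __name__ == "__main__":
    # Three fragments of 8 frames each, query lands on frame 5 of the last one.
    print([frames_to_encode(i, 3, 5, 8, True) for i in range(3)])  # [8, 8, 5]

    # Integers stand in for feature tensors.
    long_buf, short_buf = evict_until_fits(list(range(6)), list(range(4)), 8)
    print(len(long_buf), len(short_buf))  # 6 1
```

Since pop(0) on a Python list shifts every remaining element, collections.deque with popleft() would be a reasonable substitute if the buffers grow large, assuming nothing else in the code indexes into them as lists.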