Open zeyun-zhong opened 2 days ago
Hello zeyun, thanks for your attention! This prediction loss is achieved by setting the learning targets in the chat. The implementation looks like this:
```python
chat = [
    {'role': 'system', 'content': '...'},
    {'role': 'user', 'content': 'cool?'},
    {'role': 'stream', 'num_frames': 3, 'learn': 3},
    {'role': 'assistant', 'content': 'cool.', 'learn': True},
]
```
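As a rough sketch of how such a chat might be flattened into a prompt string (the role prefixes and separators below are illustrative assumptions, not the repo's exact template; only the 10-token frame size comes from the example):

```python
# Hypothetical flattening of the chat above into a prompt string.
# A frame is rendered as 10 <v> placeholder tokens (per the example below);
# frames are joined with ',' and wrapped in '[' ']' as separators.
FRAME = "<v>" * 10

def render(chat):
    parts = []
    for msg in chat:
        if msg['role'] == 'system':
            parts.append(f"System: {msg['content']}")
        elif msg['role'] == 'user':
            parts.append(f"User: {msg['content']}")
        elif msg['role'] == 'stream':
            parts.append("[" + ",".join(FRAME for _ in range(msg['num_frames'])) + "]")
        elif msg['role'] == 'assistant':
            parts.append(f"Assistant: {msg['content']}")
    return "\n".join(parts)
```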
If you use `tokenizer.apply_chat_template` to inspect the prompt, you will get:

```
...
User: cool?
[<v><v><v><v><v><v><v><v><v><v>,<v><v><v><v><v><v><v><v><v><v>,<v><v><v><v><v><v><v><v><v><v>]
Assistant: cool.
```
Here each frame contains 10 tokens. As you can see, we set {'role': 'stream', 'num_frames': 3, 'learn': 3}, so every frame (specifically, its last vision token) learns a target. If you look at the input_ids and labels after the data_collator, you will see that the last vision token of each frame, `<v>`, learns the next token, which is either `,` or `]\n`.
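A toy illustration of the resulting input_ids/labels (the token ids are made up, not the repo's vocabulary): every position is masked out except the last vision token of each frame, which learns the following separator.

```python
# Hypothetical token ids for '<v>', ',' and ']\n'; -100 is the usual
# cross-entropy ignore index.
V, COMMA, RBRACK = 100, 11, 12
IGNORE = -100

frames = 3
input_ids = []
for f in range(frames):
    input_ids += [V] * 10                       # 10 vision tokens per frame
    input_ids += [COMMA if f < frames - 1 else RBRACK]  # separator after frame

labels = [IGNORE] * len(input_ids)
for f in range(frames):
    last_v = f * 11 + 9                          # last vision token of frame f
    labels[last_v] = input_ids[last_v + 1]       # learn ',' or ']\n'
```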
We set the `,` token as frame_interval_token_id, since it represents the interval token between frames. It is regarded as the EOS token during inference: if the next predicted token is frame_interval_token_id, we simply continue inputting frames; otherwise we do language generation. This is defined in
- https://github.com/showlab/videollm-online/blob/main/models/tokenization_live.py
- https://github.com/showlab/videollm-online/blob/main/data/stream.py
- https://github.com/showlab/videollm-online/blob/main/data/data_collator.py
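The inference rule described above can be sketched as follows (`model_step` and the token ids are stand-ins for illustration, not the repo's actual API):

```python
# Hedged sketch of streaming inference: if the model predicts the
# frame-interval token, stay silent and wait for the next frame;
# otherwise switch to language generation until EOS.
FRAME_INTERVAL_ID = 11   # hypothetical id of the ',' token
EOS_ID = 2               # hypothetical end-of-sequence id

def stream_inference(frames, model_step):
    response = []
    for frame in frames:
        next_id = model_step(frame)        # one decoding step per new frame
        if next_id == FRAME_INTERVAL_ID:
            continue                       # treated as "EOS": keep streaming
        while next_id != EOS_ID:           # otherwise, generate language
            response.append(next_id)
            next_id = model_step(None)
    return response
```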
In short, tokenization_live.py converts the chat template to a string paragraph and computes the learning ranges, stream.py calls it, and data_collator.py maps the learning ranges from string indices to input_ids via the return_offsets_mapping argument of the Hugging Face tokenizer. You can see the details here:
https://huggingface.co/docs/transformers/en/internal/tokenization_utils
Then during training, we learn that "EOS" by simply applying the CE loss to input_ids and labels.
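For concreteness, here is a minimal pure-Python version of that loss: cross-entropy averaged over the positions whose label is not the ignore index (the logits and labels are illustrative, not real model outputs):

```python
import math

# Cross-entropy with ignore_index, as used over (input_ids, labels):
# masked positions (label == -100) contribute nothing to the loss.
def ce_loss(logits, labels, ignore_index=-100):
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue
        z = max(row)                              # stabilize log-sum-exp
        log_norm = z + math.log(sum(math.exp(v - z) for v in row))
        total += -(row[label] - log_norm)         # negative log-probability
        count += 1
    return total / count
```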
As you can see, we compute the labels in the dataloader, which is model-agnostic and easy to extend. Furthermore, we naturally enjoy the speed benefits of multiple dataloader workers.
Thank you very much for the detailed explanation. Now I understand your approach better.
I have another issue when running demo.inference:

```
TypeError: PreTrainedTokenizerFast._batch_encode_plus() got an unexpected keyword argument 'add_stream_prompt'
```

which is raised at line 33 in inference.py. I am wondering whether you use a custom tokenizer.apply_chat_template function?
Thank you for the excellent project. I have a question regarding the streaming EOS prediction discussed in Section 3.1 of the paper. Could you please specify in which file this part is implemented?