yuanzhoulvpi2017 / zero_nlp

中文nlp解决方案(大模型、数据、模型、训练、推理)
MIT License
3.03k stars 368 forks source link

train_llava 关于数据集构建的问题 #192

Closed weiaicunzai closed 1 week ago

weiaicunzai commented 3 weeks ago

请教一下,为啥构造数据的时候,要把question和answer concat到一起呢?这样输入进网络的话,网络不就知道了应该输出什么吗?感谢。

        input_ids = torch.concat(
            [
                q_input_ids,
                a_input_ids,
                torch.tensor(self.processor.tokenizer.eos_token_id).reshape(1, -1),
            ],
            axis=1,
        )

来自: https://github.com/yuanzhoulvpi2017/zero_nlp/blob/main/train_llava/train_llava/data.py#L108

weiaicunzai commented 1 week ago

还是next token prediction的方式

yuanzhoulvpi2017 commented 1 week ago

哈哈哈哈,看来你懂了奥~这玩意和sft一样的