ssbuild / chatglm2_finetuning

chatglm2 6b finetuning and alpaca finetuning
Apache License 2.0

Just a question: about `if sup` #10

Open magnificent1208 opened 1 year ago

magnificent1208 commented 1 year ago
def process(cls, tokenizer: ChatGLMTokenizer, config, a_ids, b_ids, max_seq_length, sptoken: typing.List, ensure_answer_min_length=1, sup=True):
    # Truncate the prompt (a_ids), leaving room for the answer (b_ids) and reserved tokens.
    input_ids = a_ids[:max_seq_length - len(b_ids) - 3 - ensure_answer_min_length] + b_ids
    a_len = len(input_ids) - len(b_ids)
    input_ids = input_ids[:max_seq_length - 3] + [config.eos_token_id]
    if sup:
        # Mask the prompt with -100 so only the answer and EOS serve as targets.
        labels = [-100] * a_len + input_ids[a_len:]
    else:
        # Use every token as a target.
        labels = copy.deepcopy(input_ids)
    # Prepend the special tokens; they are masked in the labels.
    input_ids = sptoken + input_ids
    labels = [-100] * len(sptoken) + labels

    d = TokenIdsFinal.process(input_ids, labels, max_seq_length, tokenizer)
    return [d]

What exactly does this `if sup` mean? Is it the difference between pretraining and SFT?

ssbuild commented 1 year ago

Yes, you can think of it that way.

magnificent1208 commented 1 year ago

Thanks for the answer.

magnificent1208 commented 1 year ago

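After reading your code, my understanding is this: pretraining and fine-tuning compute the loss in the same way; the only difference is how the labels are built.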

def process(cls, tokenizer: ChatGLMTokenizer, config, a_ids, b_ids, max_seq_length, sptoken: typing.List, sliding_size=None, sup=True):
    if sliding_size is None:
        sliding_size = max_seq_length
    ds = []
    input_ids_qa = a_ids + b_ids + [config.eos_token_id]
    if sup:
        labels_all = [-100] * len(a_ids) + b_ids
    else:
        labels_all = copy.deepcopy(input_ids_qa)

a_ids: the question (q)
b_ids: the answer (a)
[config.eos_token_id]: the end-of-sequence token

With supervised learning, the labels mask out the q part. With unsupervised learning (pretraining), the labels are q + a, the complete question-answer pair.
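To make the two labelling modes concrete, here is a minimal sketch. `build_labels` is a hypothetical helper written for illustration, not a function from the repository, and the token ids are made up. It follows the first snippet in keeping the EOS token as a target, and it leaves out truncation and the `sptoken` prefix.

```python
import copy

# -100 is the default ignore_index of PyTorch's CrossEntropyLoss.
IGNORE_INDEX = -100


def build_labels(a_ids, b_ids, eos_token_id, sup=True):
    """Hypothetical helper: return (input_ids, labels) for one q/a pair."""
    input_ids = a_ids + b_ids + [eos_token_id]
    if sup:
        # SFT: mask the question so only the answer and EOS are targets.
        labels = [IGNORE_INDEX] * len(a_ids) + input_ids[len(a_ids):]
    else:
        # Pretraining style: every token is a target.
        labels = copy.deepcopy(input_ids)
    return input_ids, labels


# Made-up ids: a three-token question, a two-token answer, EOS id 2.
input_ids, sft_labels = build_labels([11, 12, 13], [21, 22], eos_token_id=2, sup=True)
_, pt_labels = build_labels([11, 12, 13], [21, 22], eos_token_id=2, sup=False)

print(sft_labels)  # [-100, -100, -100, 21, 22, 2]
print(pt_labels)   # [11, 12, 13, 21, 22, 2]
```

In both modes `input_ids` is identical; only the labels change.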

Is this understanding correct?
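The point that both modes share one loss can be illustrated with a small pure-Python sketch of next-token cross-entropy that skips positions labelled -100. This is an assumption-laden illustration, not the repository's training code: `masked_next_token_loss` is a hypothetical name, and the logits are toy values over a four-token vocabulary.

```python
import math

IGNORE_INDEX = -100


def masked_next_token_loss(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean cross-entropy of next-token predictions, skipping ignored labels.

    logits[t] holds the vocabulary scores predicted at position t, and the
    target for position t is labels[t + 1] (the usual causal-LM shift).
    Returns (mean_loss, number_of_positions_counted).
    """
    total, count = 0.0, 0
    for step_logits, target in zip(logits[:-1], labels[1:]):
        if target == ignore_index:
            continue  # masked position: contributes nothing to the loss
        peak = max(step_logits)
        log_z = peak + math.log(sum(math.exp(x - peak) for x in step_logits))
        total += log_z - step_logits[target]
        count += 1
    return (total / count if count else 0.0), count


# Toy uniform logits: 4 positions, vocabulary of 4 tokens.
logits = [[0.0] * 4 for _ in range(4)]

loss_sft, n_sft = masked_next_token_loss(logits, [-100, -100, 2, 3])
loss_pt, n_pt = masked_next_token_loss(logits, [0, 1, 2, 3])
print(n_sft, n_pt)  # 2 3
```

The same function handles both label layouts. The SFT labels simply reduce the number of positions that are averaged over.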