Tele-AI / Telechat


Questions about Xingchen (TeleChat) support for AutoModelForSequenceClassification tasks #71

Open tcoln opened 1 week ago

tcoln commented 1 week ago

I tried adding a TelechatForSequenceClassification class to modeling_telechat in the Xingchen open-source repository, following the Qwen implementation and TeleChat's own code respectively. The two attempts fail in different ways: one raises an error when loading the model, and the other trains but the loss does not decrease. Could the team help look into how to support the AutoModelForSequenceClassification task?

class TelechatForSequenceClassification_tele(TelechatPreTrainedModel):

    _tied_weights_keys = ["lm_head.weight"]
    # _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config: TelechatConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = TelechatModel(config)
        # classification head: projects hidden states to num_labels scores
        self.lm_head = nn.Linear(config.hidden_size, self.num_labels, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        # the backbone is stored as self.transformer; its embedding layer is word_embeddings
        return self.transformer.word_embeddings

    def set_input_embeddings(self, value):
        self.transformer.word_embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: torch.Tensor):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
            self,
            input_ids: torch.LongTensor,
            past_key_values: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            **kwargs,
    ) -> dict:
        if past_key_values:
            input_ids = input_ids[:, -1].unsqueeze(-1)
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    def forward(
            self,
            input_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
            attention_mask: Optional[torch.Tensor] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            labels: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            **deprecated_arguments,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(lm_logits.device)
            else:
                sequence_lengths = -1

        # pool the per-token logits at the last non-padding position of each sequence
        pooled_logits = lm_logits[torch.arange(batch_size, device=lm_logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(lm_logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

class TelechatForSequenceClassification_qwen(TelechatPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = TelechatModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
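
For the Auto API to pick up either of these classes, they also have to be registered. Below is a minimal sketch of the two usual routes; it assumes modeling_telechat.py and configuration_telechat.py are importable and that the checkpoint's model_type is "telechat", so adjust the module and attribute names to the actual files:

# Option 1: register the class in the current Python process (no checkpoint changes needed)
from transformers import AutoConfig, AutoModelForSequenceClassification
from configuration_telechat import TelechatConfig                     # assumed module path
from modeling_telechat import TelechatForSequenceClassification_qwen  # assumed module path

AutoConfig.register("telechat", TelechatConfig)   # skip if the config type is already registered
AutoModelForSequenceClassification.register(TelechatConfig, TelechatForSequenceClassification_qwen)

# Option 2: ship the mapping with the checkpoint and load with trust_remote_code=True,
# by adding an entry to the "auto_map" section of the checkpoint's config.json:
#   "AutoModelForSequenceClassification": "modeling_telechat.TelechatForSequenceClassification_qwen"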


The code for loading the model is as follows: [screenshot]
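
For reference, a minimal loading sketch is shown below. This is not the code from the screenshot; the checkpoint path and label count are placeholders, and it assumes the classification class is reachable via the auto_map entry or the register call described above:

from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification

model_dir = "./telechat-7b"   # hypothetical local checkpoint path

tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
config.num_labels = 2         # placeholder: number of classes in the task
if config.pad_token_id is None:
    # both classification heads pool at the last non-padding token, so a pad token id is required
    config.pad_token_id = tokenizer.pad_token_id

# the classification head is new, so from_pretrained will warn that its weights are newly initialized
model = AutoModelForSequenceClassification.from_pretrained(
    model_dir, config=config, trust_remote_code=True
)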

shunxing12345 commented 5 days ago

OK, we'll take a look on our side.

ge-xing commented 4 days ago

I first ran classification with the Qwen2 model on this dataset: https://huggingface.co/datasets/knowledgator/Scientific-text-classification, training with DeepSpeed (a sketch of this kind of setup is given after this comment):

[screenshot]

The loss looks normal:

[screenshot: training loss curve]

I'll try telechat-7b next and report back to you.
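
For reference, a rough sketch of that kind of run is below. It is not the actual training script; the model id, dataset column names, and hyperparameters are assumptions, and it uses the HF Trainer so a DeepSpeed config can be passed through the deepspeed argument:

from datasets import load_dataset
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          Trainer, TrainingArguments)

model_dir = "Qwen/Qwen2-7B"   # hypothetical model id
dataset = load_dataset("knowledgator/Scientific-text-classification")

tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def tokenize(batch):
    # assumes a "text" column; adjust to the dataset's real column name
    return tokenizer(batch["text"], truncation=True, max_length=512)

tokenized = dataset.map(tokenize, batched=True)

num_labels = len(set(tokenized["train"]["label"]))   # assumes an integer "label" column
model = AutoModelForSequenceClassification.from_pretrained(
    model_dir, num_labels=num_labels, trust_remote_code=True
)
model.config.pad_token_id = tokenizer.pad_token_id

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=4,
    learning_rate=1e-5,
    num_train_epochs=1,
    deepspeed="ds_config.json",   # hypothetical DeepSpeed config file
)

# the default data collator pads batches and renames "label" to "labels"
trainer = Trainer(model=model, args=args, train_dataset=tokenized["train"], tokenizer=tokenizer)
trainer.train()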

shunxing12345 commented 4 days ago

This is an automatic vacation reply from QQ Mail. Hello, I am currently on vacation and cannot reply to your email in person. I will get back to you as soon as possible after the vacation ends.