Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

training=False when use a pretrained model like BERT #20128

Closed huangfu170 closed 1 month ago

huangfu170 commented 1 month ago

Bug description

I use a BERT model from transformers for text classification. I found that in `training_step()` the whole model has `training=True` while the BERT submodule has `training=False`. I want to know why; thanks for any help. My code is included under "How to reproduce the bug" below.

What version are you seeing the problem on?

v2.2

How to reproduce the bug

import pytorch_lightning as pl
import torch
import torch.nn as nn
from transformers import BertModel

loss_function = nn.CrossEntropyLoss()  # stand-in: the original loss_function is not defined in the snippet

class TestModel(pl.LightningModule):
    def __init__(self, bert_config):
        super().__init__()
        # Pretrained encoder; note that from_pretrained() returns the model in eval mode
        self.bert = BertModel.from_pretrained(bert_config._name_or_path)
        self.project_final = nn.Linear(768, 80)

    def forward(self, input_ids, attention_mask):
        output = self.bert(input_ids, attention_mask=attention_mask)
        output = output[0][:, 0, :]          # [CLS] token representation
        output = self.project_final(output)
        return output

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        label = batch["label"]
        output = self(input_ids, attention_mask)
        loss = loss_function(output, label)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=2e-5)  # the original used lr=cf.lr from a config object
        return optimizer


data_module = ContrastiveDataModule()  # user-defined LightningDataModule (not shown)
model = TestModel(bert_config)
trainer = pl.Trainer(max_epochs=100, accelerator="gpu", devices=[2])
trainer.fit(model, data_module)

Error messages and logs

self.training=True and self.bert.training=False

Environment

Current environment

```
#- PyTorch Lightning Version: 2.2.0
#- PyTorch Version: 1.13.1
#- Python version: 3.9.19
#- OS: Linux
#- CUDA/cuDNN version: 12.0
#- GPU models and configuration: 5 NVIDIA 3090 GPUs on 1 node
#- How you installed Lightning (`conda`, `pip`, source): pip
```

More info

The model has two submodules: the project_final module's training flag is True, matching the top-level "self" model, but the pretrained submodule's flag differs (it is False).
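
For illustration, the mismatch can be observed by printing the flags inside `training_step` (a minimal diagnostic sketch; the print calls are only for illustration):

```python
def training_step(self, batch, batch_idx):
    # Diagnostic only: the LightningModule and the Linear head are in train mode,
    # but the submodule loaded via from_pretrained() is not.
    print(self.training)                # True
    print(self.project_final.training)  # True
    print(self.bert.training)           # False
    ...
```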

cc @borda

huangfu170 commented 1 month ago

It seems that I should use `self.bert = BertModel.from_pretrained(bert_config._name_or_path).train()`, which is what https://lightning.ai/lightning-ai/studios/text-classification-with-pytorch-lightning does. But it is odd to have to switch it to train mode manually, given that the whole model is automatically in train mode in training_step. The example in the docs does not mention this either: https://lightning.ai/docs/pytorch/stable/advanced/transfer_learning.html#example-bert-nlp. It would be clearer to update the docs, or to switch the pretrained model to train mode automatically.
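
A minimal sketch of that workaround: HuggingFace's `from_pretrained()` returns the model in eval mode, so the flag is switched explicitly right after loading; the `on_train_start` hook shown below is just an alternative place to do the same thing.

```python
class TestModel(pl.LightningModule):
    def __init__(self, bert_config):
        super().__init__()
        # from_pretrained() returns the encoder in eval mode; switch it to train mode explicitly
        self.bert = BertModel.from_pretrained(bert_config._name_or_path).train()
        self.project_final = nn.Linear(768, 80)

    # Alternative: set the mode in a hook that runs once training begins
    def on_train_start(self):
        self.bert.train()
```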

awaelchli commented 1 month ago

Hey @huangfu170, you can look at my response here for why it's like that: #20105. Yes, it would be great to update the example, let's do it! Contributions for that are welcome.

huangfu170 commented 1 month ago

> Hey @huangfu170, you can look at my response here for why it's like that: #20105. Yes, it would be great to update the example, let's do it! Contributions for that are welcome.

Oh, I had not looked at that PR. Thank you very much for the help; we found the reason. The documentation could be updated for clarity. Thank you again! Please close the issue.