zjunlp / OntoProtein

[ICLR 2022] OntoProtein: Protein Pretraining With Gene Ontology Embedding

run_contact.sh error #8

Closed sa5r closed 2 years ago

sa5r commented 2 years ago

Hi, I set up a fresh environment for running the script, and when I run run_contact.sh I get the following error in "contact-ontoprotein.out":

Running Prediction
  Num examples = 40
  Batch size = 1
Traceback (most recent call last):
  File "run_downstream.py", line 286, in <module>
    main()
  File "run_downstream.py", line 281, in main
    predictions_family, input_ids_family, metrics_family = trainer.predict(test_dataset)
  File "/home/sakher/miniconda3/envs/onto2/lib/python3.8/site-packages/transformers/trainer.py", line 2358, in predict
    output = eval_loop(
  File "/data3/sakher/onto2/OntoProtein/src/benchmark/trainer.py", line 217, in evaluation_loop
    loss, logits, labels, prediction_score = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
  File "/data3/sakher/onto2/OntoProtein/src/benchmark/trainer.py", line 50, in prediction_step
    prediction_score['precision_at_l2'] = logits[3]['precision_at_l2']
KeyError: 'precision_at_l2'

Alexzhuan commented 2 years ago

Hi,

To report the P@K metrics for different values of K (P@K is the precision over the top K predicted contacts), we made some changes to the tape library. The changes in tape/models/modeling_utils.py are as follows:


# Module-level imports (already present at the top of modeling_utils.py):
import torch
from torch import nn
from torch.nn import functional as F


# line 843
class PairwiseContactPredictionHead(nn.Module):

    def __init__(self, hidden_size: int, ignore_index=-100):
        super().__init__()
        self.predict = nn.Sequential(
            nn.Dropout(), nn.Linear(2 * hidden_size, 2))
        self._ignore_index = ignore_index

    def forward(self, inputs, sequence_lengths, targets=None):
        # Pairwise features: element-wise product and difference of every residue pair
        prod = inputs[:, :, None, :] * inputs[:, None, :, :]
        diff = inputs[:, :, None, :] - inputs[:, None, :, :]
        pairwise_features = torch.cat((prod, diff), -1)
        prediction = self.predict(pairwise_features)
        # Symmetrize, since contact(i, j) == contact(j, i)
        prediction = (prediction + prediction.transpose(1, 2)) / 2
        prediction = prediction[:, 1:-1, 1:-1].contiguous()  # remove start/stop tokens
        outputs = (prediction,)

        if targets is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=self._ignore_index)
            contact_loss = loss_fct(
                prediction.view(-1, 2), targets.view(-1))
            metrics = {'precision_at_l5':
                       self.compute_precision_at_l5(sequence_lengths, prediction, targets),
                       'precision_at_l2':
                       self.compute_precision_at_l2(sequence_lengths, prediction, targets),
                       'precision_at_l':
                       self.compute_precision_at_l(sequence_lengths, prediction, targets)}
            loss_and_metrics = (contact_loss, metrics)
            outputs = (loss_and_metrics,) + outputs

        return outputs

    def compute_precision_at_l5(self, sequence_lengths, prediction, labels):
        # P@(L/5): precision over the L/5 most-confident predicted contacts with
        # sequence separation (j - i) >= 6, where L is the sequence length.
        with torch.no_grad():
            valid_mask = labels != self._ignore_index
            seqpos = torch.arange(valid_mask.size(1), device=sequence_lengths.device)
            x_ind, y_ind = torch.meshgrid(seqpos, seqpos)
            valid_mask &= ((y_ind - x_ind) >= 6).unsqueeze(0)
            probs = F.softmax(prediction, 3)[:, :, :, 1]
            valid_mask = valid_mask.type_as(probs)
            correct = 0
            total = 0
            for length, prob, label, mask in zip(sequence_lengths, probs, labels, valid_mask):
                masked_prob = (prob * mask).view(-1)
                most_likely = masked_prob.topk(length // 5, sorted=False)
                selected = label.view(-1).gather(0, most_likely.indices)
                correct += selected.sum().float()
                total += selected.numel()
            return correct / total

    def compute_precision_at_l2(self, sequence_lengths, prediction, labels):
        # P@(L/2): as above, over the L/2 most-confident predicted contacts.
        with torch.no_grad():
            valid_mask = labels != self._ignore_index
            seqpos = torch.arange(valid_mask.size(1), device=sequence_lengths.device)
            x_ind, y_ind = torch.meshgrid(seqpos, seqpos)
            valid_mask &= ((y_ind - x_ind) >= 6).unsqueeze(0)
            probs = F.softmax(prediction, 3)[:, :, :, 1]
            valid_mask = valid_mask.type_as(probs)
            correct = 0
            total = 0
            for length, prob, label, mask in zip(sequence_lengths, probs, labels, valid_mask):
                masked_prob = (prob * mask).view(-1)
                most_likely = masked_prob.topk(length // 2, sorted=False)
                selected = label.view(-1).gather(0, most_likely.indices)
                correct += selected.sum().float()
                total += selected.numel()
            return correct / total

    def compute_precision_at_l(self, sequence_lengths, prediction, labels):
        # P@L: as above, over the L most-confident predicted contacts.
        with torch.no_grad():
            valid_mask = labels != self._ignore_index
            seqpos = torch.arange(valid_mask.size(1), device=sequence_lengths.device)
            x_ind, y_ind = torch.meshgrid(seqpos, seqpos)
            valid_mask &= ((y_ind - x_ind) >= 6).unsqueeze(0)
            probs = F.softmax(prediction, 3)[:, :, :, 1]
            valid_mask = valid_mask.type_as(probs)
            correct = 0
            total = 0
            for length, prob, label, mask in zip(sequence_lengths, probs, labels, valid_mask):
                masked_prob = (prob * mask).view(-1)
                most_likely = masked_prob.topk(length, sorted=False)
                selected = label.view(-1).gather(0, most_likely.indices)
                correct += selected.sum().float()
                total += selected.numel()
            return correct / total
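
With the patched class above in scope, a quick sanity check on random tensors (the shapes and labels below are illustrative, not from the repository) should now expose all three P@K keys that trainer.py reads in prediction_step:

```python
import torch

# Illustrative: batch of 2 sequences, 50 tokens each (incl. start/stop), hidden size 128.
head = PairwiseContactPredictionHead(hidden_size=128)
inputs = torch.randn(2, 50, 128)
sequence_lengths = torch.tensor([48, 48])    # residue counts after removing start/stop
targets = torch.randint(0, 2, (2, 48, 48))   # dummy binary contact maps

(loss, metrics), prediction = head(inputs, sequence_lengths, targets)
print(sorted(metrics))  # ['precision_at_l', 'precision_at_l2', 'precision_at_l5']
```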
sa5r commented 2 years ago

Thanks! That made prediction work. However, I got very low accuracy results, which may be explained by the following warning. Do you think running the run_pretrain code fine-tunes the model for this downstream task?

04/26/2022 23:14:45 - INFO - main - Task name: contact, output mode: token-level-classification
Some weights of the model checkpoint at ./model/contact/OntoproteinModel were not used when initializing BertForOntoProteinContactPrediction: ['cls.predictions.decoder.bias', 'bert.pooler.dense.weight', 'cls.predictions.transform.dense.bias', 'bert.pooler.dense.bias', 'cls.predictions.decoder.weight', 'pooler.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'pooler.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']

- This IS expected if you are initializing BertForOntoProteinContactPrediction from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForOntoProteinContactPrediction from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

Some weights of BertForOntoProteinContactPrediction were not initialized from the model checkpoint at ./model/contact/OntoproteinModel and are newly initialized: ['classifier.weight', 'predict.predict.1.weight', 'classifier.bias', 'predict.predict.1.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
04/26/2022 23:15:42 - INFO - main - Test
Alexzhuan commented 2 years ago

I think the poor performance is because you evaluated on the downstream task without fine-tuning the pre-trained model first. Before evaluating on a downstream task, you need to fine-tune the pre-trained model on it.

sa5r commented 2 years ago

So I can fine-tune the model by running run_pretrain.py, but does this fine-tune it for all the downstream tasks? I don't see any parameter that specifies which downstream task to fine-tune for.

Alexzhuan commented 2 years ago

You need to set --do_train to True in the script script/run_{task}.sh and then run the script to fine-tune the pre-trained model on the downstream task's dataset, as in the sketch below.
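
For example, the training switch in script/run_contact.sh would look something like this. Apart from --do_train, run_downstream.py, and the checkpoint path (all taken from this thread), the flags are assumptions based on the script using Hugging Face's Trainer arguments, not the script's exact contents:

```bash
# Sketch only: --do_train is confirmed by this thread; the other flags are
# assumed from the Hugging Face Trainer convention.
python run_downstream.py \
  --task_name contact \
  --model_name_or_path ./model/contact/OntoproteinModel \
  --output_dir ./output/contact \
  --do_train True \
  --do_predict True
```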

sa5r commented 2 years ago

I set --do_train to True, but the process gets killed by an error that I am still trying to track down. It may be caused by some missing graph files, such as go_graph.txt, go_detail.txt, and go_leaf.txt, which are not found in the provided ProteinKG25 archive but are generated by gen_onto_protein_data.py around line 290 of the script. Note that I only want to rely on the provided graph, not generate one with the script.

Alexzhuan commented 2 years ago

Could you post the detailed error information for the error that occurred during training?

The 'missing' files are only intermediate products used for analyzing the training data; they are not used in training or inference. Note that fine-tuning the pre-trained model on a downstream task only relies on the language-model checkpoint, which can be loaded with the Hugging Face API.
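
As a minimal sketch of what "loaded with the Hugging Face API" means here (the checkpoint directory is the one from the log above; BertModel and BertTokenizer are assumptions based on the Bert* class names in the warning):

```python
from transformers import BertModel, BertTokenizer

checkpoint = "./model/contact/OntoproteinModel"  # local checkpoint dir from the log above

# Only the language-model weights are loaded; task-specific heads (e.g. the
# pairwise contact-prediction head) are newly initialized and trained during
# fine-tuning on the downstream dataset.
tokenizer = BertTokenizer.from_pretrained(checkpoint)
model = BertModel.from_pretrained(checkpoint)
```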