sacdallago / biotrainer

Biological prediction models made simple.
https://biocentral.cloud/app
Academic Free License v3.0

Per-residue predictions differ between batches and single input embeddings #100

Closed: SebieF closed this issue 3 months ago

SebieF commented 3 months ago

We've identified an issue where per-residue predictions for protein sequences differ depending on whether a sequence is processed as part of a batch or on its own. The inconsistency especially affects the CNN model and likely also the LightAttention model, and it is most visible for residues at the end of sequences.

Key observations:

Batch processing vs. single input:
    * Predictions for the same sequence differ when processed in a batch compared to individually.
    * Differences are more pronounced for residues at the end of sequences.

Padding effects:
    * The current implementation doesn't properly handle padded sequences in batches.
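    * This leads to inconsistencies, especially for shorter sequences in a batch: they are padded to the length of the longest sequence, and the padded positions then influence the predictions for their real residues.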

Model architecture considerations:
    * Both CNN and LightAttention models are affected.
    * The issue is more noticeable in the CNN model because its convolution windows near the end of a shorter sequence extend into the padded positions.
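The NumPy sketch below reproduces the end-of-sequence effect with a hypothetical toy two-layer CNN (random weights, kernel size 7, "same" zero padding). It is an illustration under these assumptions, not biotrainer's actual model. A sequence predicted on its own and the same sequence zero-padded as part of a batch agree everywhere except in the last `kernel // 2` residues: the first layer's bias and ReLU turn the padded positions into non-zero activations, and the second layer's windows near the sequence end read them.

```python
import numpy as np

def conv1d_same(x, weight, bias):
    """'Same' 1D convolution with zero padding.

    x: (length, in_ch), weight: (kernel, in_ch, out_ch), bias: (out_ch,)
    """
    k = weight.shape[0]
    pad = k // 2
    xp = np.pad(x, ((pad, pad), (0, 0)))  # zero-pad the sequence axis only
    out = np.empty((x.shape[0], weight.shape[2]))
    for i in range(x.shape[0]):
        window = xp[i:i + k]  # (kernel, in_ch) window centred on residue i
        out[i] = np.einsum("ki,kio->o", window, weight) + bias
    return out

def two_layer_cnn(x, w1, b1, w2, b2):
    # Hypothetical toy model: conv -> ReLU -> conv, no masking
    hidden = np.maximum(conv1d_same(x, w1, b1), 0.0)
    return conv1d_same(hidden, w2, b2)

rng = np.random.default_rng(0)
kernel, emb_dim, hidden_dim, n_classes = 7, 4, 8, 3
w1 = rng.normal(size=(kernel, emb_dim, hidden_dim))
b1 = rng.normal(size=hidden_dim)
w2 = rng.normal(size=(kernel, hidden_dim, n_classes))
b2 = rng.normal(size=n_classes)

seq_len, batch_len = 10, 16
emb = rng.normal(size=(seq_len, emb_dim))
padded = np.pad(emb, ((0, batch_len - seq_len), (0, 0)))  # padded as inside a batch

single = two_layer_cnn(emb, w1, b1, w2, b2)
batched = two_layer_cnn(padded, w1, b1, w2, b2)[:seq_len]

diff = np.abs(single - batched).max(axis=1)  # largest difference per residue
print(np.round(diff, 4))
```

With a single convolution and zero-valued padding the outputs would still match; the mismatch appears once an intermediate layer writes non-zero values into the padded positions.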

Normalization layers:
    * BatchNorm layers in the LightAttention model contributed to the discrepancy.
    * Replacing BatchNorm with LayerNorm partially mitigated the issue.
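One possible mechanism for the BatchNorm effect, sketched in NumPy under the assumption that batch statistics are in use (as with PyTorch's BatchNorm1d in training mode; in eval mode it uses running statistics instead): normalizing with per-channel statistics over the whole batch makes a sequence's output depend on the other sequences and on the padding, whereas a LayerNorm over the feature dimension normalizes each residue independently. The function names are hypothetical, and the layout is channels-last for readability, whereas BatchNorm1d expects (batch, channels, length).

```python
import numpy as np

def batch_norm_batch_stats(x, eps=1e-5):
    # x: (batch, length, channels). Mean/variance per channel over batch AND length,
    # i.e. normalization with batch statistics; padded positions are included.
    mean = x.mean(axis=(0, 1), keepdims=True)
    var = x.var(axis=(0, 1), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def layer_norm(x, eps=1e-5):
    # Mean/variance over channels only, separately for every residue.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(1)
seq_a = rng.normal(size=(5, 4))  # short sequence: 5 residues, 4 channels
seq_b = rng.normal(size=(9, 4))  # longer sequence in the same batch
batch = np.stack([np.pad(seq_a, ((0, 4), (0, 0))), seq_b])  # (2, 9, 4)

bn_single = batch_norm_batch_stats(seq_a[None])[0]
bn_batched = batch_norm_batch_stats(batch)[0, :5]
ln_single = layer_norm(seq_a[None])[0]
ln_batched = layer_norm(batch)[0, :5]

print("BatchNorm max diff:", np.abs(bn_single - bn_batched).max())
print("LayerNorm max diff:", np.abs(ln_single - ln_batched).max())
```

LayerNorm only removes the normalization part of the dependence; padded positions can still leak in through convolutions or attention, which fits the observation that the swap only partially mitigated the issue.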

Proposed solutions:

Implement mask-aware processing:
    * Introduce a mask to identify non-padded elements in batched inputs.
    * Modify model forward passes to respect this mask.
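A minimal sketch of mask-aware processing, using a hypothetical toy two-layer CNN with random weights (all names are illustrative, not biotrainer's API): a boolean mask is built from the sequence lengths, and the hidden activations are multiplied by it between the two convolutions so that padded positions are exactly zero again. With the mask in place, each sequence's batched predictions match its single-input predictions up to floating-point rounding.

```python
import numpy as np

def conv1d_same(x, weight, bias):
    """Zero-padded 'same' 1D convolution. x: (length, in_ch), weight: (k, in_ch, out_ch)."""
    k = weight.shape[0]
    pad = k // 2
    xp = np.pad(x, ((pad, pad), (0, 0)))
    windows = np.stack([xp[i:i + k] for i in range(x.shape[0])])  # (length, k, in_ch)
    return np.einsum("lki,kio->lo", windows, weight) + bias

def make_mask(lengths, max_len):
    """Boolean mask (batch, max_len): True for real residues, False for padding."""
    return np.arange(max_len)[None, :] < np.asarray(lengths)[:, None]

def masked_two_layer_cnn(x, mask, w1, b1, w2, b2):
    """x: (length, emb_dim), mask: (length,). Padded positions are zeroed before each conv."""
    x = x * mask[:, None]
    hidden = np.maximum(conv1d_same(x, w1, b1), 0.0)  # ReLU
    hidden = hidden * mask[:, None]  # bias/ReLU made padded positions non-zero; reset them
    return conv1d_same(hidden, w2, b2)

rng = np.random.default_rng(0)
kernel, emb_dim, hidden_dim, n_classes = 7, 4, 8, 3
w1 = rng.normal(size=(kernel, emb_dim, hidden_dim))
b1 = rng.normal(size=hidden_dim)
w2 = rng.normal(size=(kernel, hidden_dim, n_classes))
b2 = rng.normal(size=n_classes)

lengths = [6, 11]
seqs = [rng.normal(size=(n, emb_dim)) for n in lengths]
max_len = max(lengths)
batch = np.stack([np.pad(s, ((0, max_len - len(s)), (0, 0))) for s in seqs])
mask = make_mask(lengths, max_len)

max_diffs = []
for i, (seq, n) in enumerate(zip(seqs, lengths)):
    batched = masked_two_layer_cnn(batch[i], mask[i], w1, b1, w2, b2)[:n]
    single = masked_two_layer_cnn(seq, np.ones(n, dtype=bool), w1, b1, w2, b2)
    max_diffs.append(np.abs(batched - single).max())
    print(f"sequence {i}: max |batched - single| = {max_diffs[-1]:.2e}")
```

In a PyTorch model working on (batch, channels, length) tensors, the same idea would be `hidden * mask.unsqueeze(1)` between layers; for attention layers, padded positions are typically excluded from the softmax instead, for example by setting their scores to `-inf`.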

Consistent padding strategy:
    * Implement custom padding that doesn't affect original sequence lengths.
    * Use F.pad (torch.nn.functional.pad) for explicit control over padding in convolutional layers.
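One way to pad without losing the original sequence lengths, sketched with NumPy's `np.pad` (`pad_batch` and `unpad_predictions` are hypothetical helpers): each embedding is zero-padded at the end to the batch maximum, the lengths are returned alongside the batch so they can be used to build the mask for the forward pass, and predictions are cut back to those lengths afterwards. In PyTorch, `F.pad(emb, (0, 0, 0, n))` appends `n` zero rows to a `(length, dim)` tensor, and `F.pad(x, (k // 2, k // 2))` on a `(batch, channels, length)` tensor makes the padding of a `Conv1d` with `padding=0` explicit.

```python
import numpy as np

def pad_batch(embeddings):
    """Pad per-residue embeddings (dict seq_id -> (length, dim)) into one array.

    Returns the padded batch, the original lengths, and the seq_ids in batch order.
    """
    seq_ids = list(embeddings)
    lengths = np.array([embeddings[s].shape[0] for s in seq_ids])
    max_len = lengths.max()
    batch = np.stack([
        # pad only at the end of the sequence axis, with zeros
        np.pad(embeddings[s], ((0, max_len - embeddings[s].shape[0]), (0, 0)))
        for s in seq_ids
    ])
    return batch, lengths, seq_ids

def unpad_predictions(batch_predictions, lengths, seq_ids):
    """Cut every per-residue prediction back to its original sequence length."""
    return {s: batch_predictions[i, :n] for i, (s, n) in enumerate(zip(seq_ids, lengths))}

embeddings = {"seq1": np.ones((3, 4)), "seq2": np.ones((5, 4))}
batch, lengths, seq_ids = pad_batch(embeddings)
restored = unpad_predictions(batch, lengths, seq_ids)
print(batch.shape, lengths.tolist())  # (2, 5, 4) [3, 5]
```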

Consider adding a test function like this to test_inference.py:
        def test_single(self):
            # Predict all sequences together as one batch
            r2c_dict = self.inferencer_r2c.from_embeddings(self.per_residue_embeddings,
                                                           include_probabilities=True)

            # Predict every sequence on its own
            r2c_single_result_dict = {}
            for seq_id, emb in self.per_residue_embeddings.items():
                r2c_single_result_emb = self.inferencer_r2c.from_embeddings({seq_id: emb},
                                                                            include_probabilities=True)
                r2c_single_result_dict[seq_id] = r2c_single_result_emb["mapped_probabilities"][seq_id]

            # Batched and single-input probabilities must agree for every residue
            prediction_errors = self._compare_predictions(r2c_dict["mapped_probabilities"],
                                                          r2c_single_result_dict)
            self.assertEqual(len(prediction_errors), 0,
                             msg=f"Batch/single prediction mismatch: {prediction_errors}")
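The test calls `self._compare_predictions`, which is not shown. A hypothetical stand-in could compare both result dictionaries with a numerical tolerance rather than exact equality, since batched and single computations can differ by floating-point rounding even when padding is handled correctly. It assumes `mapped_probabilities` holds nested dicts and lists of floats; that structure and the tolerances are assumptions.

```python
import math
import numbers

def compare_predictions(batch_results, single_results, rel_tol=1e-5, abs_tol=1e-6):
    """Return a list of human-readable mismatches between two nested result structures."""
    errors = []

    def _compare(path, a, b):
        if isinstance(a, dict) and isinstance(b, dict):
            if a.keys() != b.keys():
                errors.append(f"{path}: key mismatch {sorted(map(str, a))} vs {sorted(map(str, b))}")
                return
            for key in a:
                _compare(f"{path}/{key}", a[key], b[key])
        elif isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
            if len(a) != len(b):
                errors.append(f"{path}: length {len(a)} vs {len(b)}")
                return
            for i, (x, y) in enumerate(zip(a, b)):
                _compare(f"{path}[{i}]", x, y)
        elif isinstance(a, numbers.Real) and isinstance(b, numbers.Real):
            # Probabilities: compare with a tolerance instead of exact equality
            if not math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol):
                errors.append(f"{path}: {a} vs {b}")
        elif a != b:
            errors.append(f"{path}: {a!r} vs {b!r}")

    _compare("", batch_results, single_results)
    return errors
```

Inside the test class, `_compare_predictions(self, a, b)` could simply return `compare_predictions(a, b)`.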

Issue written together with Anthropic Claude 3.5 Sonnet.