We've identified an issue where per-residue predictions for protein sequences differ when processed as part of a batch versus individually. The inconsistency primarily affects the CNN model and likely the LightAttention model as well, and it is most pronounced for residues at the end of sequences.
Key observations:
Batch processing vs. single input:
* Predictions for the same sequence differ when processed in a batch compared to individually.
* Differences are more pronounced for residues at the end of sequences.
Padding effects:
* The current implementation doesn't properly handle padded sequences in batches.
* This leads to inconsistencies, especially for shorter sequences in a batch.
Model architecture considerations:
* Both CNN and LightAttention models are affected.
* The issue is more noticeable in the CNN model due to its convolutional nature: near the end of a sequence, the convolution's receptive field extends into the padded region, as the sketch below illustrates.
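A minimal sketch of the effect (the two-layer architecture, dimensions, and kernel size are illustrative assumptions, not the repository's actual model): the first convolution writes non-zero values (bias plus boundary contributions) into the padded region, which the second convolution's receptive field then reads near the sequence end.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical two-layer CNN; sizes are illustrative, not the repository's.
model = nn.Sequential(
    nn.Conv1d(128, 64, kernel_size=7, padding=3),
    nn.ReLU(),
    nn.Conv1d(64, 32, kernel_size=7, padding=3),
).eval()

x = torch.randn(1, 128, 50)  # one sequence of length 50

with torch.no_grad():
    single = model(x)
    # Simulate batching with a longer sequence: append 20 zero "residues",
    # run the model, then crop the output back to the original length.
    batched = model(F.pad(x, (0, 20)))[:, :, :50]

# The first conv turns the zero padding into non-zero values (bias + ReLU),
# which the second conv then reads near the end of the real sequence.
print(torch.allclose(single, batched))      # False
diff = (single - batched).abs().amax(dim=(0, 1))
print(diff[:3], diff[-3:])                  # zero at the start, non-zero at the end
```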
Normalization layers:
* BatchNorm layers in the LightAttention model contributed to the discrepancy.
* Replacing BatchNorm with LayerNorm partially mitigated the issue.
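A toy illustration (arbitrary dimensions) of why BatchNorm output for one sequence can depend on what else is in the batch, while LayerNorm does not. This shows training mode, where BatchNorm uses batch statistics; it is a sketch of the mechanism, not the repository's layers:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm1d(8).train()  # training mode: statistics come from the batch
ln = nn.LayerNorm(8)            # normalizes each position over its channels

x = torch.randn(1, 8, 50)       # [batch, channels, length]
other = torch.randn(3, 8, 50)   # unrelated sequences sharing the batch

# BatchNorm: the output for x changes when other sequences join the batch,
# because per-channel mean/variance are computed across the whole batch.
print(torch.allclose(bn(x), bn(torch.cat([x, other]))[:1]))   # False

# LayerNorm (channels-last) depends only on each position itself.
xl, otherl = x.transpose(1, 2), other.transpose(1, 2)
print(torch.allclose(ln(xl), ln(torch.cat([xl, otherl]))[:1]))  # True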
Proposed solutions:
Implement mask-aware processing:
* Introduce a mask to identify non-padded elements in batched inputs.
* Modify model forward passes to respect this mask.
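A sketch of what a mask-aware forward pass could look like; the model below is a stand-in with placeholder layer names and sizes, not the actual CNN/LightAttention implementation:

```python
import torch
import torch.nn as nn

class MaskedConvModel(nn.Module):
    """Illustrative mask-aware forward pass; layers and sizes are placeholders."""

    def __init__(self, embedding_dim: int = 128, hidden_dim: int = 64, num_classes: int = 10):
        super().__init__()
        self.conv = nn.Conv1d(embedding_dim, hidden_dim, kernel_size=7, padding=3)
        self.classifier = nn.Conv1d(hidden_dim, num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # x: [batch, embedding_dim, length]; mask: [batch, length], 1.0 for real residues.
        mask = mask.unsqueeze(1)             # [batch, 1, length], broadcasts over channels
        x = torch.relu(self.conv(x * mask))
        # Re-zero the padded positions so the values a conv writes there (bias,
        # boundary contributions) cannot leak into the next layer's receptive field.
        return self.classifier(x * mask) * mask

# Building the mask from the true sequence lengths in a batch:
lengths = torch.tensor([50, 37])
mask = (torch.arange(int(lengths.max()))[None, :] < lengths[:, None]).float()
```

Re-zeroing after each layer makes the padded positions behave exactly like the implicit zero padding a sequence sees when processed alone, which is what restores batch/single consistency.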
Consistent padding strategy:
* Implement custom padding that leaves predictions for the original, unpadded residues unchanged.
* Use F.pad for explicit control over padding in convolutional layers.
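A sketch of the F.pad approach (a hypothetical drop-in wrapper, not existing repository code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ExplicitPadConv1d(nn.Module):
    """Conv1d that pads explicitly via F.pad instead of the padding= argument,
    so the padding amount, side, and value are controlled in one place."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
        super().__init__()
        self.kernel_size = kernel_size
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        left = (self.kernel_size - 1) // 2
        right = self.kernel_size - 1 - left
        # Explicit zero padding keeps the output length equal to the input length
        # and matches what the single-sequence path sees at the boundaries.
        return self.conv(F.pad(x, (left, right), mode="constant", value=0.0))
```

Keeping the padding explicit makes it easier to guarantee that the batched and single-sequence code paths see identical values at the sequence boundaries.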
Consider adding a test function like this to test_inference.py:
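One possible shape for such a test; `get_model`, the input layout `[batch, embedding_dim, length]`, and the tolerance are assumptions to adapt to the repository's actual fixtures:

```python
import torch
import torch.nn.functional as F

def test_batched_predictions_match_single():
    """Per-residue predictions must not depend on batch composition or padding."""
    torch.manual_seed(42)
    model = get_model("CNN").eval()   # hypothetical factory; adapt to the repo
    embedding_dim = 128               # assumed embedding size

    short = torch.randn(1, embedding_dim, 37)
    long = torch.randn(1, embedding_dim, 50)

    with torch.no_grad():
        single = model(short)
        # Zero-pad the short sequence to the batch's max length and batch both.
        batch = torch.cat([F.pad(short, (0, 50 - 37)), long], dim=0)
        batched = model(batch)[:1, :, :37]  # predictions for the short sequence

    # Must hold in particular for the residues at the end of the sequence.
    assert torch.allclose(single, batched, atol=1e-6)
```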
Issue written together with Anthropic Claude 3.5 Sonnet