wells-wood-research / timed-design

Protein Sequence Design with Deep Learning and Tooling like Monte Carlo Sampling and Analysis

Charge and Polar have Ca and Cb channels inverted #64

Closed universvm closed 4 months ago

universvm commented 8 months ago

Due to a bug in the training of the models, the training frames for the Charge and Polar models had the channel order:

[C, N, O, Cb, Ca, CHARGE/POLARITY]

rather than

[C, N, O, Ca, Cb, CHARGE/POLARITY]

This makes the Charge and Polar models unusable as they are. A very hacky quick fix would be to swap the channel order at prediction time until we are able to spend the time and computation needed to retrain the models.

I propose modifying the function load_batch at https://github.com/wells-wood-research/timed-design/blob/abc6afa63ee0fcd467b11f68be4accec432aac43/design_utils/utils.py#L487

to include something like this, ONLY FOR THE CHARGE AND POLAR MODELS:

# Extract frame from batch:
for i, (pdb_code, chain_id, residue_id, _) in enumerate(data_point_batch):
    # Extract frame:
    residue_frame = np.asarray(dataset[pdb_code][chain_id][residue_id][()])

    # Check if the frame has the correct shape (final dimension is 6) for swapping:
    if residue_frame.ndim == 4 and residue_frame.shape[-1] == 6:
        # Swap only the 4th and 5th channels (index 3 and 4)
        residue_frame[..., 3], residue_frame[..., 4] = residue_frame[..., 4].copy(), residue_frame[..., 3].copy()

    X[i] = residue_frame

    # Extract residue label:
    y[i] = dataset[pdb_code][chain_id][residue_id].attrs["encoded_residue"]

The Charge and Polar frames have a final dimension of 6, which is what the shape check above relies on.
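To make the swap easy to check in isolation, here is a minimal sketch on a toy array. The helper name swap_ca_cb and its default indices are hypothetical (they are not part of design_utils); the only assumption taken from the issue is that Ca and Cb sit at indices 3 and 4 of the last axis.

```python
import numpy as np


def swap_ca_cb(frame: np.ndarray, ca_index: int = 3, cb_index: int = 4) -> np.ndarray:
    """Return a copy of `frame` with two channels of the last axis exchanged.

    Hypothetical helper for illustration; the indices are assumptions
    based on the channel order described in this issue.
    """
    swapped = frame.copy()
    # Fancy indexing on the right-hand side produces a copy, so the two
    # channels are exchanged without one overwriting the other.
    swapped[..., [ca_index, cb_index]] = frame[..., [cb_index, ca_index]]
    return swapped


# Toy frame: 2x2x2 voxels with 6 channels, where channel k holds the value k.
frame = np.broadcast_to(np.arange(6), (2, 2, 2, 6)).copy()
fixed = swap_ca_cb(frame)

print(frame[0, 0, 0])  # [0 1 2 3 4 5]
print(fixed[0, 0, 0])  # [0 1 2 4 3 5]
```

Returning a copy rather than mutating in place keeps the original frame available for comparison; applying the helper twice gives back the original order.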

sunal1996 commented 8 months ago

In case retraining takes too long, and in case we want to allocate something else to that 6th channel, should we do it like this instead?

encoder_list = dataset.attrs["atom_encoder"]

for i, (pdb_code, chain_id, residue_id, _) in enumerate(data_point_batch):
    # Extract frame:
    residue_frame = np.asarray(dataset[pdb_code][chain_id][residue_id][()])

    # Check that the frame is 4D and the encoder contains "Q" or "P" before swapping:
    if residue_frame.ndim == 4 and ("Q" in encoder_list or "P" in encoder_list):
        # Swap only the 4th and 5th channels (index 3 and 4)
        residue_frame[..., 3], residue_frame[..., 4] = residue_frame[..., 4].copy(), residue_frame[..., 3].copy()

    X[i] = residue_frame

    # Extract residue label:
    y[i] = dataset[pdb_code][chain_id][residue_id].attrs["encoded_residue"]
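A small sketch of the encoder-based check on its own, since the membership test is easy to get subtly wrong. The helper name needs_channel_swap and the example encoder lists are hypothetical; it assumes atom_encoder is a sequence of plain strings (if the HDF5 attribute comes back as bytes, it would need decoding first).

```python
def needs_channel_swap(encoder_list, marker_codes=("Q", "P")):
    """Return True if any marker code appears in the encoder.

    Hypothetical helper; "Q" and "P" are the codes used in the snippet above.
    """
    return any(code in encoder_list for code in marker_codes)


# Example encoders (made up for illustration).
plain_encoder = ["C", "N", "O", "CA", "CB"]
charge_encoder = ["C", "N", "O", "CA", "CB", "Q"]

print(needs_channel_swap(plain_encoder))   # False
print(needs_channel_swap(charge_encoder))  # True

# Pitfall: `"Q" or "P" in plain_encoder` parses as `"Q" or ("P" in plain_encoder)`,
# and the non-empty string "Q" is always truthy, so the check never fails.
print(bool("Q" or "P" in plain_encoder))   # True
```

Each code has to be tested for membership separately, either with any(...) as here or with two explicit `in` checks joined by `or`.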

universvm commented 4 months ago

Comparison_summary 3.pdf