
ESM-IF sampling for dimer #499

Open · adrienchaton opened this issue 1 year ago

adrienchaton commented 1 year ago

Hello,

In the following code https://github.com/facebookresearch/esm/blob/main/esm/inverse_folding/multichain_util.py#L80 I see that a chain can be sampled conditioned on the backbone coordinates of, e.g., a dimer.

I am interested in the following case: I have the dimer backbone; first I redesign chain A, and second I redesign chain B while taking into account the sequence already sampled for chain A. Since the decoder is auto-regressive, I could achieve this by swapping the concatenation order: first the fixed sequence, then the padding, and last the target sequence to design.
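To make the intended ordering concrete, here is a toy illustration (chain A = "MKV" already designed, chain B of length 4 still to design, padding_length = 2). The token list is what would be passed as partial_seq to model.sample, and the backbone coordinates would be concatenated in the same order, with np.nan coordinates at the pad positions as in _concatenate_coords:

fixed_seq = "MKV"   # chain A, already redesigned in a first pass
target_len = 4      # length of chain B, to be redesigned now
padding_length = 2

# fixed chain first, then inter-chain padding, then the masked target chain
partial_seq = list(fixed_seq) + ["<pad>"] * padding_length + ["<mask>"] * target_len
# -> ['M', 'K', 'V', '<pad>', '<pad>', '<mask>', '<mask>', '<mask>', '<mask>']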

Does this make sense to you? Here is a code snippet:

import numpy as np
import torch


@torch.inference_mode()
def custom_sample_in_dimer(model, fixed_coord, fixed_seq, target_coord, target_seq,
                           pos_to_sample=None, coord_to_mask=None, temperature=1.,
                           padding_length=10, max_repeat=5):
    # Redesign target_seq conditioned on the dimer backbone and on fixed_seq (the other chain).
    # Unlike sample_sequence_in_complex (target chain first, followed by pad tokens), the fixed
    # chain is concatenated first so that it ends up in the auto-regressive context of the target chain.
    # + option to redesign only a subset of residues in target_seq (pos_to_sample)
    # + option to mask a subset of the backbone in target_coord (coord_to_mask)
    if coord_to_mask is not None:
        target_coord = target_coord.copy()  # avoid mutating the caller's array
        for i in coord_to_mask:
            target_coord[i, :] = float('inf')  # inf marks a masked/missing coordinate
    pad_coords = np.full((padding_length, 3, 3), np.nan, dtype=np.float32)  # nan marks inter-chain padding
    all_coords = np.concatenate([fixed_coord, pad_coords, target_coord], axis=0)
    # fixed chain tokens, then pad tokens, then the target chain with <mask> at positions to redesign
    padding_pattern = list(fixed_seq) + ['<pad>'] * padding_length
    for i, c in enumerate(target_seq):
        if pos_to_sample is None or i in pos_to_sample:
            padding_pattern.append('<mask>')
        else:
            padding_pattern.append(c)
    device = next(model.parameters()).device  # TODO: check issues with cuda device handling
    while True:
        sampled = model.sample(all_coords, partial_seq=padding_pattern, temperature=temperature)  # , device=device)
        sampled = sampled[-len(target_seq):]  # keep only the target chain
        # count_char_repeats is a user helper returning the longest run of identical residues;
        # resample if the sequence is too repetitive
        if max_repeat <= 0 or count_char_repeats(sampled) < max_repeat:
            return sampled
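For context, this is roughly how I would chain the two passes (a sketch only; the file name and chain ids are placeholders, and I assume the usual loading utilities esm.pretrained.esm_if1_gvp4_t16_142M_UR50, load_structure and extract_coords_from_complex):

import esm
import esm.inverse_folding

model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
model = model.eval()

# per-chain backbone coordinates and native sequences of the dimer
structure = esm.inverse_folding.util.load_structure("dimer.pdb")
coords, native_seqs = esm.inverse_folding.multichain_util.extract_coords_from_complex(structure)

# pass 1: redesign chain A conditioned on the full dimer backbone (library function)
new_seq_A = esm.inverse_folding.multichain_util.sample_sequence_in_complex(
    model, coords, target_chain_id="A", temperature=1.0)

# pass 2: redesign chain B conditioned on the dimer backbone and on the new chain A sequence
new_seq_B = custom_sample_in_dimer(
    model,
    fixed_coord=coords["A"], fixed_seq=new_seq_A,
    target_coord=coords["B"], target_seq=native_seqs["B"],
    temperature=1.0)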

For the backbone encoder, I think this should not be a problem (?). But I am wondering about having padding tokens in the auto-regressive context: I assume that during training the decoder never sees padding tokens in its left context (i.e. they are all on the right side), so could this be a problem for inference/sampling?
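If the mid-sequence pad tokens were a real problem for the decoder, I would expect it to show up empirically. A simple check I could run (sketch, reusing model, coords and native_seqs from above, plus a small seq_recovery helper) is to compare native sequence recovery for chain B between the standard ordering and the swapped ordering:

def seq_recovery(designed, native):
    # fraction of positions where the designed residue matches the native one
    return sum(d == n for d, n in zip(designed, native)) / len(native)

# baseline: library sampling of chain B (target chain first, pads only to its right)
baseline_B = esm.inverse_folding.multichain_util.sample_sequence_in_complex(
    model, coords, target_chain_id="B", temperature=1.0)

# swapped ordering: chain B sampled with chain A's sequence and the pad tokens in its left context
swapped_B = custom_sample_in_dimer(
    model, coords["A"], native_seqs["A"], coords["B"], native_seqs["B"], temperature=1.0)

print("recovery (baseline):", seq_recovery(baseline_B, native_seqs["B"]))
print("recovery (swapped) :", seq_recovery(swapped_B, native_seqs["B"]))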