Hello,

In the following code https://github.com/facebookresearch/esm/blob/main/esm/inverse_folding/multichain_util.py#L80 I see that a chain can be sampled according to the backbone coordinates of e.g. a dimer.

I am interested in the following case: I have the dimer backbone, and I first redesign chain A, then redesign chain B while taking the sequence sampled for chain A into account. Since the decoder is autoregressive, I could achieve this by swapping the concatenation order: first the fixed sequence, then the padding, and last the target sequence to design.
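Schematically, the concatenated decoder input I have in mind would be (lengths not to scale):

coords: [fixed chain A coords] + [padding_length nan coords] + [target chain B coords]
tokens: [fixed chain A seq] + ['<pad>'] * padding_length + ['<mask>' or kept chain B residues]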
Does this make sense to you? Here is a code snippet:
import numpy as np
import torch

def count_char_repeats(seq):
    # Assumed helper (defined elsewhere in my code): length of the longest
    # run of identical consecutive characters in seq.
    longest = current = 1
    for prev, cur in zip(seq, seq[1:]):
        current = current + 1 if cur == prev else 1
        longest = max(longest, current)
    return longest

@torch.inference_mode()
def custom_sample_in_dimer(model, fixed_coord, fixed_seq, target_coord, target_seq,
                           pos_to_sample=None, coord_to_mask=None, temperature=1.,
                           padding_length=10, max_repeat=5):
    # Redesign target_seq given the dimer backbone and the fixed_seq of the other chain.
    # Differs from sample_sequence_in_complex, where the target chain comes first,
    # followed by pad tokens: here the order is fixed chain, pads, target chain,
    # so the autoregressive decoder conditions on the fixed sequence.
    # Added options:
    #   - pos_to_sample: subset of residue indices to redesign in target_seq
    #   - coord_to_mask: subset of backbone positions to mask in target_coord
    if coord_to_mask is not None:
        target_coord = target_coord.copy()  # avoid mutating the caller's array
        for i in coord_to_mask:
            target_coord[i, :] = float('inf')
    pad_coords = np.full((padding_length, 3, 3), np.nan, dtype=np.float32)
    all_coords = np.concatenate([fixed_coord, pad_coords, target_coord], axis=0)
    # Fixed chain is given verbatim; target positions are either kept or masked for sampling.
    padding_pattern = list(fixed_seq) + ['<pad>'] * padding_length
    for i, c in enumerate(target_seq):
        if pos_to_sample is None or i in pos_to_sample:
            padding_pattern.append('<mask>')
        else:
            padding_pattern.append(c)
    device = next(model.parameters()).device  # TODO: check issues with cuda device handling
    valid_sample = False
    while not valid_sample:
        sampled = model.sample(all_coords, partial_seq=padding_pattern, temperature=temperature)  # , device=device)
        sampled = sampled[-len(target_seq):]
        # Reject degenerate samples containing long single-residue repeats.
        if not (max_repeat > 0 and count_char_repeats(sampled) >= max_repeat):
            valid_sample = True
    return sampled
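For context, the two-pass redesign I have in mind would look roughly like this (coords_a/coords_b and the *_seq variables are placeholders for my own data, not names from the repo):

# pass 1: redesign chain A conditioned on chain B's native sequence
new_seq_a = custom_sample_in_dimer(model, coords_b, native_seq_b, coords_a, native_seq_a)
# pass 2: redesign chain B conditioned on the freshly sampled chain A sequence
new_seq_b = custom_sample_in_dimer(model, coords_a, new_seq_a, coords_b, native_seq_b)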
For the backbone encoder, I think this should be no problem (?). But I am wondering about having padding tokens in the autoregressive context: I assume the decoder never sees padding tokens during training (i.e. they all sit at the right-hand end of the sequence), so should this be a problem for inference/sampling?
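In the meantime, one check I could run myself: score the same designed chain-B sequence under both concatenation orders and compare the average log-likelihoods; if mid-sequence pad tokens confuse the decoder, the swapped layout should score noticeably worse. A rough sketch, assuming esm.inverse_folding.util.score_sequence as used in the repo's examples, and assuming the alphabet tokenizer splits '<pad>' out of a plain string (coords_a/coords_b, seq_a, designed_seq_b are my placeholders):

import numpy as np
from esm.inverse_folding.util import score_sequence

pad_coords = np.full((10, 3, 3), np.nan, dtype=np.float32)

# swapped layout: fixed chain A first, then pads, then the designed chain B
ll_swapped, _ = score_sequence(
    model, alphabet,
    np.concatenate([coords_a, pad_coords, coords_b], axis=0),
    seq_a + '<pad>' * 10 + designed_seq_b)

# standard layout (as in sample_sequence_in_complex): target chain first
ll_standard, _ = score_sequence(
    model, alphabet,
    np.concatenate([coords_b, pad_coords, coords_a], axis=0),
    designed_seq_b + '<pad>' * 10 + seq_a)

print(f'avg log-likelihood: swapped={ll_swapped:.3f}, standard={ll_standard:.3f}')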