Closed: chanjed closed this issue 2 years ago.
It looks like `model.sample()` now takes a `partial_seq` argument, which can be used to supply a fixed sequence with any variable regions replaced by `<mask>` tokens.
Totally right, thanks @naailkhan28. @chanjed, let us know whether this resolves your feature request!
Thanks! Resolved
I managed to get this working by adding a new input of comma-separated amino-acid positions (1-based) and adjusting the functions for single-chain and multichain (complex) sampling:
def sample_seq_singlechain(model, alphabet, args):
    """Sample sequences for one chain, optionally keeping chosen residues fixed.

    Loads the backbone coordinates and native sequence of ``args.chain`` from
    ``args.pdbfile``, draws ``args.num_samples`` sequences with
    ``model.sample`` and writes them to ``args.outpath`` as FASTA records.

    ``args.fixed`` is either ``None`` (every position is sampled) or a
    comma-separated string of 1-based residue positions whose native amino
    acid is kept. All other positions are passed to the model as ``<mask>``
    tokens through ``partial_seq``.

    Args:
        model: inverse-folding model whose ``sample`` accepts ``coords``,
            ``partial_seq`` and ``temperature``.
        alphabet: not used by this function; kept for interface compatibility.
        args: namespace providing ``pdbfile``, ``chain``, ``fixed``,
            ``outpath``, ``num_samples`` and ``temperature``.

    Raises:
        ValueError: if ``args.fixed`` contains a non-integer token or a
            position outside ``1..len(native_seq)``. This is raised before
            the output file is opened.
    """
    coords, native_seq = esm.inverse_folding.util.load_coords(args.pdbfile, args.chain)

    # fix residues: native amino acid at fixed positions, <mask> everywhere else
    fixed_res = set()  # 0-based positions; a set keeps membership tests O(1)
    fixedprot = None
    if args.fixed is not None:
        # Blank tokens are skipped so inputs like "3,5," or "3, ,5" do not crash int().
        fixed_res = {int(tok) - 1 for tok in args.fixed.split(',') if tok.strip()}
        # Reject positions that match no residue instead of silently ignoring them.
        out_of_range = sorted(p + 1 for p in fixed_res if not 0 <= p < len(native_seq))
        if out_of_range:
            raise ValueError(
                f'Fixed positions {out_of_range} are outside 1..{len(native_seq)}'
            )
        fixedprot = [
            aa if pos in fixed_res else '<mask>' for pos, aa in enumerate(native_seq)
        ]
    # end fix residues

    print('Native sequence loaded from structure file:')
    print(native_seq)
    print(f'Saving sampled sequences to {args.outpath}.')
    Path(args.outpath).parent.mkdir(parents=True, exist_ok=True)
    with open(args.outpath, 'w') as f:
        for i in range(args.num_samples):
            print(f'\nSampling.. ({i+1} of {args.num_samples})')
            sampled_seq = model.sample(coords, partial_seq=fixedprot, temperature=args.temperature)
            print('Sampled sequence:')
            print(sampled_seq)
            f.write(f'>sampled_seq_{i+1}\n')
            f.write(sampled_seq + '\n')
            recovery = np.mean([(a == b) for a, b in zip(native_seq, sampled_seq)])
            print('Sequence recovery:', recovery)
            if fixed_res:
                # Fixed positions are expected to reproduce the native residue,
                # which inflates the overall recovery above; also report the
                # recovery over the designed (masked) positions only.
                designed = [
                    a == b
                    for pos, (a, b) in enumerate(zip(native_seq, sampled_seq))
                    if pos not in fixed_res
                ]
                if designed:  # np.mean([]) would warn and return nan
                    print('Sequence recovery (designed positions only):', np.mean(designed))
def sample_seq_multichain(model, alphabet, args):
    """Sample sequences for one chain of a complex, conditioned on all chains.

    Loads every chain of ``args.pdbfile``, concatenates their coordinates
    with the target chain ``args.chain``, and draws ``args.num_samples``
    sequences for the target chain with ``model.sample``. The results are
    written to ``args.outpath`` as FASTA records.

    ``args.fixed`` is either ``None`` (the whole target chain is sampled) or
    a comma-separated string of 1-based positions in the target chain whose
    native amino acid is kept. The other target positions are ``<mask>``, and
    all other chains are ``<pad>`` so the model does not sample them.

    Args:
        model: inverse-folding model whose ``sample`` accepts ``coords``,
            ``partial_seq`` and ``temperature``.
        alphabet: not used by this function; kept for interface compatibility.
        args: namespace providing ``pdbfile``, ``chain``, ``fixed``,
            ``outpath``, ``num_samples`` and ``temperature``.

    Raises:
        ValueError: if ``args.fixed`` contains a non-integer token or a
            position outside ``1..target_chain_len``. This is raised before
            the output file is opened.
    """
    structure = esm.inverse_folding.util.load_structure(args.pdbfile)
    coords, native_seqs = esm.inverse_folding.multichain_util.extract_coords_from_complex(structure)
    target_chain_id = args.chain
    native_seq = native_seqs[target_chain_id]
    print('Native sequence loaded from structure file:')
    print(native_seq)
    print('\n')

    # fix residues -- the partial sequence is the same for every sample, so
    # it is built once here rather than inside the sampling loop.
    target_chain_len = coords[target_chain_id].shape[0]
    all_coords = esm.inverse_folding.multichain_util._concatenate_coords(coords, target_chain_id)
    fixed_res = set()  # 0-based positions within the target chain
    if args.fixed is not None:
        # Blank tokens are skipped so inputs like "3,5," or "3, ,5" do not crash int().
        fixed_res = {int(tok) - 1 for tok in args.fixed.split(',') if tok.strip()}
        # Reject positions that match no residue instead of silently ignoring them.
        out_of_range = sorted(p + 1 for p in fixed_res if not 0 <= p < target_chain_len)
        if out_of_range:
            raise ValueError(
                f'Fixed positions {out_of_range} are outside 1..{target_chain_len}'
            )
    # Supply padding tokens for other chains to avoid unused sampling for speed.
    # The target chain occupies the first target_chain_len slots (the slice of
    # the sampled output below relies on the same layout). It must be <mask>
    # wherever it is not fixed: leaving it as <pad> when no --fixed positions
    # were given meant nothing was sampled.
    padding_pattern = ['<pad>'] * all_coords.shape[0]
    for n in range(target_chain_len):
        padding_pattern[n] = native_seq[n] if n in fixed_res else '<mask>'
    # end fixed residues

    print(f'Saving sampled sequences to {args.outpath}.')
    Path(args.outpath).parent.mkdir(parents=True, exist_ok=True)
    with open(args.outpath, 'w') as f:
        for i in range(args.num_samples):
            print(f'\nSampling.. ({i+1} of {args.num_samples})')
            sampled = model.sample(all_coords, partial_seq=padding_pattern,
                                   temperature=args.temperature)
            sampled_seq = sampled[:target_chain_len]
            print('Sampled sequence:')
            print(sampled_seq)
            f.write(f'>sampled_seq_{i+1}\n')
            f.write(sampled_seq + '\n')
            recovery = np.mean([(a == b) for a, b in zip(native_seq, sampled_seq)])
            print('Sequence recovery:', recovery)
            if fixed_res:
                # Fixed positions are expected to reproduce the native residue,
                # which inflates the overall recovery above; also report the
                # recovery over the designed (masked) positions only.
                designed = [
                    a == b
                    for pos, (a, b) in enumerate(zip(native_seq, sampled_seq))
                    if pos not in fixed_res
                ]
                if designed:  # np.mean([]) would warn and return nan
                    print('Sequence recovery (designed positions only):', np.mean(designed))
I have already modified the `partial_seq` argument and passed in the following input:
The middle part is masked, and I hope to design this part, but the output I received looks like this:
Could you please advise on further modifications? Thank you!
I was wondering whether there are any plans to add a feature where parts of the generated sequence are fixed while others are sampled, so that small segments of the protein can be designed. Thanks!