facebookresearch / esm

Evolutionary Scale Modeling (esm): Pretrained language models for proteins
MIT License

Inverse Folding Feature Request: Fixing parts of the protein sequence #236

Status: Closed (chanjed closed 2 years ago)

chanjed commented 2 years ago

I was wondering whether there are any plans to add a feature where parts of the generated sequence are kept fixed while the rest are sampled, so that small segments of the protein can be designed. Thanks!

naailkhan28 commented 2 years ago

It looks like model.sample() now takes a partial_seq argument, which can be used to supply a fixed sequence with any variable regions replaced by <mask> tokens.
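
A minimal sketch of how such a partial_seq could be built, assuming it is passed as a list with one token per residue: the residue letter at fixed positions and '<mask>' at positions to be sampled. The helper name build_partial_seq and the example sequence are hypothetical and not part of esm. The model.sample call appears only as a comment because it needs a loaded model and backbone coordinates.

```python
def build_partial_seq(native_seq, design_positions):
    """Return one token per residue: '<mask>' where the model should sample,
    the native residue letter everywhere else.

    design_positions holds 0-based indices (an assumption of this sketch).
    """
    design = set(design_positions)
    return [
        '<mask>' if i in design else aa
        for i, aa in enumerate(native_seq)
    ]


# Hypothetical 8-residue sequence; redesign positions 2..4 only.
partial_seq = build_partial_seq('MKTAYIAK', range(2, 5))
print(partial_seq)

# With a loaded model and coordinates, this would then be passed as:
# sampled_seq = model.sample(coords, partial_seq=partial_seq, temperature=1.0)
```

Building the list from explicit positions keeps its length equal to the sequence length, which is what a per-residue partial sequence needs.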

tomsercu commented 2 years ago

Totally right, thanks @naailkhan28. @chanjed, let us know if this indeed resolves your feature request!

chanjed commented 2 years ago

Thanks! Resolved

martinpacesa commented 2 years ago

I have managed to get this to work by adding a new input (args.fixed in the code below) that takes comma-separated amino acid positions, numbered from 1, and adjusting the functions for single-chain and complex sampling:

from pathlib import Path

import numpy as np
import esm
import esm.inverse_folding

def sample_seq_singlechain(model, alphabet, args):
    coords, native_seq = esm.inverse_folding.util.load_coords(args.pdbfile, args.chain)

    # Fixed residues: keep the native amino acid at the listed positions and
    # replace every other position with '<mask>' so that the model samples it.
    if args.fixed is not None:
        # positions are 1-based on input; convert to 0-based indices
        fixed_res = {int(i) - 1 for i in args.fixed.split(',')}
        fixedprot = [aa if i in fixed_res else '<mask>'
                     for i, aa in enumerate(native_seq)]
    else:
        # no fixed positions given: sample the whole sequence
        fixedprot = None
    # End fixed residues

    print('Native sequence loaded from structure file:')
    print(native_seq)

    print(f'Saving sampled sequences to {args.outpath}.')

    Path(args.outpath).parent.mkdir(parents=True, exist_ok=True)
    with open(args.outpath, 'w') as f:
        for i in range(args.num_samples):
            print(f'\nSampling.. ({i+1} of {args.num_samples})')
            sampled_seq = model.sample(coords, partial_seq=fixedprot, temperature=args.temperature)
            print('Sampled sequence:')
            print(sampled_seq)
            f.write(f'>sampled_seq_{i+1}\n')
            f.write(sampled_seq + '\n')

            recovery = np.mean([(a==b) for a, b in zip(native_seq, sampled_seq)])
            print('Sequence recovery:', recovery)

def sample_seq_multichain(model, alphabet, args):
    structure = esm.inverse_folding.util.load_structure(args.pdbfile)
    coords, native_seqs = esm.inverse_folding.multichain_util.extract_coords_from_complex(structure)

    target_chain_id = args.chain
    native_seq = native_seqs[target_chain_id]
    print('Native sequence loaded from structure file:')
    print(native_seq)
    print('\n')

    print(f'Saving sampled sequences to {args.outpath}.')

    Path(args.outpath).parent.mkdir(parents=True, exist_ok=True)
    with open(args.outpath, 'w') as f:
        for i in range(args.num_samples):
            print(f'\nSampling.. ({i+1} of {args.num_samples})')

            # Fixed residues: build the partial sequence for the target chain
            target_chain_len = coords[target_chain_id].shape[0]
            all_coords = esm.inverse_folding.multichain_util._concatenate_coords(coords, target_chain_id)

            # Supply padding tokens for the other chains so they are not sampled (faster)
            padding_pattern = ['<pad>'] * all_coords.shape[0]
            if args.fixed is not None:
                # positions are 1-based on input; convert to 0-based indices
                fixed_res = {int(i) - 1 for i in args.fixed.split(',')}
            else:
                # no fixed positions given: mask the whole target chain
                # (fixed_res would otherwise be undefined in the loop below)
                fixed_res = set()
            for n in range(target_chain_len):
                if n in fixed_res:
                    padding_pattern[n] = native_seq[n]
                else:
                    padding_pattern[n] = '<mask>'
            sampled = model.sample(all_coords, partial_seq=padding_pattern,
                    temperature=args.temperature)
            sampled_seq = sampled[:target_chain_len]
            # End fixed residues

            print('Sampled sequence:')
            print(sampled_seq)
            f.write(f'>sampled_seq_{i+1}\n')
            f.write(sampled_seq + '\n')

            recovery = np.mean([(a==b) for a, b in zip(native_seq, sampled_seq)])
            print('Sequence recovery:', recovery)
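
Both functions above parse args.fixed inline, and a position outside the sequence is silently ignored because it never matches an index. A stricter parser could be sketched as follows. The name parse_fixed_positions and its seq_len parameter are hypothetical. It is not part of esm, only one way the shared parsing could be factored out.

```python
def parse_fixed_positions(spec, seq_len):
    """Convert a string such as '1,5,10' (1-based) into a set of 0-based indices.

    Returns an empty set when spec is None, so callers can treat
    "no fixed positions given" as "design every position".
    """
    if spec is None:
        return set()
    positions = set()
    for item in spec.split(','):
        item = item.strip()
        if not item:
            continue  # tolerate a trailing comma
        pos = int(item)  # raises ValueError on non-numeric input
        if not 1 <= pos <= seq_len:
            raise ValueError(f'position {pos} is outside 1..{seq_len}')
        positions.add(pos - 1)
    return positions


print(sorted(parse_fixed_positions('1, 3,8', seq_len=8)))  # prints [0, 2, 7]
```

Raising on an out-of-range position makes an off-by-one mistake in the input visible instead of quietly masking a residue that was meant to stay fixed.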

Xiaoxiao0606 commented 5 months ago

I have already modified the partial_seq parameter and passed in the following input:

[screenshot: WX20240628-172003@2x]

The middle part is masked, and I hope to design this part, but the output I received looks like this:

[screenshot: WX20240628-172033@2x]

Could you please advise on further modifications? Thank you!
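
Since the screenshots are not available here, the following is only a generic sanity check, not a diagnosis. It assumes, as in the code posted earlier in this thread, that partial_seq is a list with one token per coordinate row and that the sampled sequence is a plain string. The function name check_sample is hypothetical.

```python
def check_sample(partial_seq, sampled_seq, n_coords):
    """Return a list of human-readable problems; empty if none were found."""
    if isinstance(partial_seq, str):
        # A string would be read character by character, so '<mask>' would
        # be split into six separate characters instead of one token.
        return ['partial_seq is a string; expected a list of tokens']
    problems = []
    if len(partial_seq) != n_coords:
        problems.append(
            f'partial_seq has {len(partial_seq)} tokens but there are '
            f'{n_coords} coordinate rows'
        )
    for i, tok in enumerate(partial_seq):
        if tok in ('<mask>', '<pad>'):
            continue  # designed or padded position, nothing to compare
        if i < len(sampled_seq) and sampled_seq[i] != tok:
            problems.append(
                f'position {i + 1}: fixed {tok!r} became {sampled_seq[i]!r}'
            )
    return problems


# Hypothetical 4-residue example with the two middle positions designed.
partial = ['M', '<mask>', '<mask>', 'K']
print(check_sample(partial, 'MAGK', n_coords=4))
print(check_sample(partial, 'AAGK', n_coords=4))
```

An empty list means the fixed positions were preserved and the lengths agree. Any reported problem points at the input format rather than at the sampling itself.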