generatebio / chroma

A generative model for programmable protein design
Apache License 2.0

Strange output from inpainting #24

Closed. hnisonoff closed this issue 7 months ago.

hnisonoff commented 7 months ago

I have been playing around with inpainting, and after replicating the examples you provided, I tried it out on my own protein.

I found that when inpainting a region of 40 amino acids out of 157, the output from Chroma was a protein that appeared to explode, with a completely broken chain. I then tried adjusting the various SDE / Langevin dynamics parameters, but this did not produce more stable designs. The attached zipped folder contains the following files:

  1. original_protein.pdb is the starting protein from which I am trying to inpaint a subset of the residues (indices 80-139)
  2. large_inpainting_final.cif is the structure output from the inpainting design
  3. large_inpainting_traj.cif is the trajectory

To debug this further, I tried reducing the size of the inpainted region. When inpainting only 10 residues, Chroma produced reasonable-looking designs. These files are:

  1. small_inpainting_final.cif
  2. small_inpainting_traj.cif

Next, I increased the inpainted region from 10 to 20 residues. This time the structure did not explode; however, Chroma inpainted a region that forms a knot with the rest of the protein and clashes significantly with it. These files are:

  1. medium_inpainting_final.cif
  2. medium_inpainting_traj.cif
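One way to put a number on "clashes significantly" is to count close Cα-Cα contacts between the inpainted window and the fixed residues. The sketch below is hypothetical: the function name, the 3.0 Å cutoff, the use of Cα atoms only, and the assumption of a single contiguous chain are all illustrative choices, not something Chroma provides. It works on plain coordinate tuples, so Cα coordinates extracted from any of the output files could be passed in.

```python
import math
from typing import List, Sequence, Tuple

Coord = Tuple[float, float, float]

def count_ca_clashes(
    ca: Sequence[Coord],
    designed: Sequence[int],
    cutoff: float = 3.0,  # assumed clash threshold in Angstrom
) -> List[Tuple[int, int, float]]:
    """Return (designed_index, fixed_index, distance) for every Calpha pair closer
    than `cutoff`, comparing designed residues only against fixed residues."""
    designed_set = set(designed)
    fixed = [j for j in range(len(ca)) if j not in designed_set]
    clashes = []
    for i in sorted(designed_set):
        for j in fixed:
            # Skip chain neighbours across the window boundary: they are covalently
            # linked (assumes one contiguous chain with consecutive numbering)
            if abs(i - j) == 1:
                continue
            d = math.dist(ca[i], ca[j])
            if d < cutoff:
                clashes.append((i, j, d))
    return clashes

ca = [(0.0, 0.0, 0.0), (3.8, 0.0, 0.0), (7.6, 0.0, 0.0), (1.0, 0.0, 0.0)]
print(count_ca_clashes(ca, designed=[3]))
```

A designed segment that threads through the fixed part of the protein would be expected to show up as several such pairs.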

The code I used to do the inpainting is:

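import torch
from chroma import Chroma, Protein, conditioners

device = 'cuda' if torch.cuda.is_available() else 'cpu'
chroma = Chroma()

protein = Protein('original_protein.pdb', device=device)
X, C, S = protein.to_XCS()
num_residues = C.shape[-1]  # 157 for this protein

# Fix every residue outside the inpainted window (indices 80-139 are redesigned)
residues_to_fix = [i for i in range(num_residues) if i < 80 or i > 139]
protein.sys.save_selection(gti=residues_to_fix, selname="infilling_selection")

conditioner = conditioners.SubstructureConditioner(
    protein,
    backbone_model=chroma.backbone_network,
    selection='namesel infilling_selection').to(device)

# With full_output=True, sample() returns the final protein and the trajectories
infilled_protein, trajectories = chroma.sample(
    protein_init=protein,
    conditioner=conditioner,
    langevin_factor=4.0,
    langevin_isothermal=True,
    inverse_temperature=8.0,
    full_output=True,
    steps=500)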

As mentioned above, I also tried modifying langevin_factor, langevin_isothermal, and inverse_temperature.
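Hard-coding the residue count is an easy way to go off by one: range(156) covers only indices 0-155, so a 157-residue protein would lose its last residue from the fixed set. A small hypothetical helper (the name and signature are illustrative, not part of Chroma) can derive both index lists from the protein length and an inclusive design window; it also makes it easy to confirm how many residues indices 80-139 actually cover.

```python
from typing import List, Tuple

def split_fixed_designed(num_residues: int, start: int, end: int) -> Tuple[List[int], List[int]]:
    """Split residue indices into (fixed, designed) for an inclusive window [start, end]."""
    if not 0 <= start <= end < num_residues:
        raise ValueError("design window must lie inside the protein")
    designed = list(range(start, end + 1))
    fixed = list(range(start)) + list(range(end + 1, num_residues))
    return fixed, designed

fixed, designed = split_fixed_designed(157, 80, 139)
print(len(designed), len(fixed))  # → 60 97
```

The fixed list could then be passed as gti to save_selection in place of a hand-written comprehension.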

If you can provide any insight into why this happens, or guidance on what I may be doing wrong, that would be great. It is also possible that this is simply a case where inpainting does not work well (sampling is hard, or the protein is sufficiently out of distribution). Thank you!

for_github.zip

wujiewang commented 7 months ago

@hnisonoff Thanks for the feedback!

Can you pass sde_func='langevin' as an additional sampling argument? This enables preconditioned Langevin dynamics instead of reverse diffusion, and it should be the default for infilling.
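For context on that switch, here is a generic toy sketch of what "preconditioned Langevin dynamics" means. It is not Chroma's sampler (its preconditioner and noise schedule are not described in this thread); the function names and the 2D Gaussian target are assumptions for illustration. It runs unadjusted overdamped Langevin steps x ← x + ε·M·∇log p(x) + sqrt(2ε)·M^(1/2)·ξ with a diagonal preconditioner M on a badly scaled Gaussian. Setting M to the target variances makes both coordinates relax at the same rate, which is the basic motivation for preconditioning.

```python
import math
import random
from typing import Callable, List, Optional, Sequence

def langevin_samples(
    grad_log_p: Callable[[List[float]], List[float]],
    x0: Sequence[float],
    step: float,
    n_steps: int,
    precond: Optional[Sequence[float]] = None,  # diagonal of M; None means identity
    seed: int = 0,
) -> List[List[float]]:
    """Unadjusted overdamped Langevin dynamics with a diagonal preconditioner."""
    rng = random.Random(seed)
    m = list(precond) if precond is not None else [1.0] * len(x0)
    x = list(x0)
    samples = []
    for _ in range(n_steps):
        g = grad_log_p(x)
        x = [
            xd + step * md * gd + math.sqrt(2.0 * step * md) * rng.gauss(0.0, 1.0)
            for xd, md, gd in zip(x, m, g)
        ]
        samples.append(x)
    return samples

# Toy target: independent Gaussian with very different scales per coordinate
MU = [0.0, 0.0]
SIGMA = [10.0, 0.1]

def grad_log_gaussian(x: List[float]) -> List[float]:
    return [-(xd - mu) / s**2 for xd, mu, s in zip(x, MU, SIGMA)]

def sample_std(values: List[float]) -> float:
    mean = sum(values) / len(values)
    return math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))

# Precondition with the target variances so both coordinates contract at rate `step`
samples = langevin_samples(
    grad_log_gaussian, x0=[50.0, 1.0], step=0.1, n_steps=21000,
    precond=[s**2 for s in SIGMA],
)
kept = samples[1000:]  # discard burn-in
stds = [sample_std([x[d] for x in kept]) for d in range(2)]
print([round(s / sig, 2) for s, sig in zip(stds, SIGMA)])
```

Without the preconditioner, the step would have to stay below 2σ² = 0.02 for the narrow coordinate to remain stable, and at such a step the wide coordinate would need on the order of 10^4 steps to relax.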

hnisonoff commented 7 months ago

@wujiewang thank you for the help! I added sde_func='langevin' and now it is working!

Should the example in the ChromaAPI.ipynb be modified to reflect this? Also, in that example I see the following line:

wujiewang commented 7 months ago

Yes, let me update the code! I will do a small 1.0.1 release as we gather more improvements.

Thanks for bringing this to our attention! It is quite a subtle point.

gha2012 commented 6 months ago

Hi, thank you for Chroma! I also have problems with infilling. My plan is to take an input structure and design a new domain of a specified size (e.g. 100 AA) into one of its loops.

This is the code that I use, which is a combination of code snippets that I have found in your examples:

import random
import torch
from chroma import Chroma, Protein, conditioners

device = 'cuda' if torch.cuda.is_available() else 'cpu'

outputPrefix = "extensionDesign"
inputFile = "input.pdb"
numberOfDesigns = 1
minLength = 100
maxLength = 150
designT = 0.0
insert_position = 76  # the new segment starts at this (0-based) index
buffer = 5            # residues next to the insertion whose sequence may also be redesigned

chroma = Chroma()  # load the model once, not once per design

for thisDesign in range(numberOfDesigns):
    protein = Protein(inputFile, device=device)  # pdb with target structure
    X, C, S = protein.to_XCS()

    len_extension = random.randint(minLength, maxLength)
    len_inputStructure = C.shape[-1]
    L_total = len_inputStructure + len_extension
    ip = insert_position

    # Splice: input residues [0, ip) stay in place, the extension occupies
    # [ip, ip + len_extension), and input residues [ip, end) are shifted by
    # len_extension. No original residue is dropped or left as zeros.
    X_new = torch.zeros(1, L_total, X.shape[2], 3, device=device)
    X_new[:, :ip] = X[:, :ip]
    X_new[:, ip + len_extension:] = X[:, ip:]

    # Designed residues join the chain of the residue preceding the insertion
    C_new = torch.full((1, L_total), C[0, ip - 1].item(), dtype=C.dtype, device=device)
    C_new[:, :ip] = C[:, :ip]
    C_new[:, ip + len_extension:] = C[:, ip:]

    S_new = torch.zeros(1, L_total, dtype=S.dtype, device=device)
    S_new[:, :ip] = S[:, :ip]
    S_new[:, ip + len_extension:] = S[:, ip:]

    protein = Protein(X_new, C_new, S_new, device=device)
    X, C, S = protein.to_XCS()
    assert L_total == C.shape[-1]

    # Per-position allowed amino acids: one-hot (keep the original residue)
    # outside the buffered window, all 20 allowed inside it
    mask_aa = torch.ones(1, L_total, 20, device=device)
    for i in range(L_total):
        if i < ip - buffer or i > ip + len_extension + buffer:
            mask_aa[0, i] = 0
            mask_aa[0, i, S[0, i].item()] = 1

    # Keep the coordinates of every original residue fixed
    residues_to_keep = list(range(ip)) + list(range(ip + len_extension, L_total))
    protein.sys.save_selection(gti=residues_to_keep, selname="original")

    substructureConditioner = conditioners.SubstructureConditioner(
        protein,
        backbone_model=chroma.backbone_network,
        selection='namesel original',
        rg=True).to(device)  # rg=True avoids clashes and breaks
    conditioner = conditioners.ComposedConditioner([substructureConditioner])

    protein, traj = chroma.sample(
        protein_init=protein,
        conditioner=conditioner,
        design_selection=mask_aa,
        langevin_factor=4.0,
        langevin_isothermal=True,
        inverse_temperature=8.0,
        sde_func='langevin',
        full_output=True,
        # design_t=designT,  # can be set between 0.0 (paper) and 0.5 (default, see GitHub)
        steps=500,
    )

    protein.to_PDB(f"{outputPrefix}_{thisDesign}.pdb")

It sort of works. The original structure is left untouched, but the new segment does not form a globular domain; instead, it wraps strangely around the input structure. I wonder if I am doing something completely wrong.

Thanks for your help!
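When splicing a new segment into an existing chain with tensor slices, it is easy to drop or zero out a residue at either junction (for example, a slice like 0:insert_position - 1 stops one residue early). The pure-Python sketch below uses hypothetical helper names: it builds the new-to-original index map for an insertion and derives the indices to keep fixed from it, so the junctions can be checked before any tensors are built.

```python
from typing import List, Optional

def splice_index_map(num_original: int, insert_position: int, num_inserted: int) -> List[Optional[int]]:
    """For each index of the spliced chain, return the original residue index,
    or None for a newly inserted (designed) position."""
    if not 0 <= insert_position <= num_original:
        raise ValueError("insert_position must lie within the original chain")
    return (
        list(range(insert_position))
        + [None] * num_inserted
        + list(range(insert_position, num_original))
    )

def fixed_indices(index_map: List[Optional[int]]) -> List[int]:
    """Indices in the spliced chain that carry an original residue."""
    return [i for i, orig in enumerate(index_map) if orig is not None]

index_map = splice_index_map(num_original=10, insert_position=4, num_inserted=3)
keep = fixed_indices(index_map)
# Every original residue appears exactly once, and no inserted slot is kept
assert sorted(index_map[i] for i in keep) == list(range(10))
print(keep)  # → [0, 1, 2, 3, 7, 8, 9, 10, 11, 12]
```

The kept indices equal list(range(insert_position)) + list(range(insert_position + num_inserted, total_length)), which is the form a substructure selection for the original residues would need.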

gha2012 commented 6 months ago

Do I perhaps have to combine this with a shape conditioner?

wujiewang commented 6 months ago

Hi @gha2012

What you described is possible, but the results can depend on the geometry of the motif.

I think it is a good idea to try the shape conditioner. You could also write a conditioner that uses a restraint function penalizing structures with a large radius of gyration, to make the design more globular.
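As a rough sketch of such a restraint, the functions below compute a radius of gyration from Cα coordinates and a flat-bottom penalty that is zero for compact structures and grows quadratically once Rg exceeds a target. The target uses the commonly cited empirical scaling Rg ≈ 2.2·N^0.38 Å for globular proteins; the constants, the slack factor, and the function names are assumptions. This is plain Python for illustration, not Chroma's conditioner interface: an actual conditioner would need the same computation on differentiable torch tensors so that the penalty's gradient can steer sampling.

```python
import math
from typing import Sequence, Tuple

Coord = Tuple[float, float, float]

def radius_of_gyration(coords: Sequence[Coord]) -> float:
    """Root-mean-square distance of the points from their centroid."""
    n = len(coords)
    centroid = [sum(p[k] for p in coords) / n for k in range(3)]
    return math.sqrt(sum(math.dist(p, centroid) ** 2 for p in coords) / n)

def globular_rg_target(num_residues: int, prefactor: float = 2.2, exponent: float = 0.38) -> float:
    """Assumed empirical Rg (Angstrom) of a compact globular protein."""
    return prefactor * num_residues ** exponent

def rg_penalty(coords: Sequence[Coord], k: float = 1.0, slack: float = 1.0) -> float:
    """Flat-bottom harmonic restraint: zero up to slack * target, quadratic beyond it."""
    excess = radius_of_gyration(coords) - slack * globular_rg_target(len(coords))
    return k * max(0.0, excess) ** 2

# 100 points: a fully extended line versus a compact 5 x 5 x 4 grid (3.8 A spacing)
extended = [(3.8 * i, 0.0, 0.0) for i in range(100)]
compact = [(3.8 * x, 3.8 * y, 3.8 * z) for x in range(5) for y in range(5) for z in range(4)]
print(rg_penalty(compact), rg_penalty(extended) > 0)  # → 0.0 True
```

For an insertion like the one above, one option would be to apply the penalty only to the designed segment rather than the whole protein.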