generatebio / chroma

A generative model for programmable protein design
Apache License 2.0

Improve sampling behaviour with samples > 1 #16

Open elkoz opened 7 months ago

elkoz commented 7 months ago

This looks like a really great project, thank you for making it!

I ran into a couple of inconveniences when trying to generate a large number of samples with protein_init not set to None, and I thought my fixes might be useful to other people.

First, there is currently a bug that causes only one sample to be generated when protein_init is not None, regardless of the samples parameter. This PR fixes that by expanding X_unc, C_unc and S_unc along the batch dimension before passing them to the model. Second, since there is no batching during sampling, generating a large number of samples is inconvenient, so I also added a batch_size parameter that makes it possible to generate a larger number of outputs with one command.
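The batch-expansion part of the fix can be sketched roughly as follows. The helper name `expand_batch` and the concrete tensor shapes are illustrative only, not the actual code in the PR; the tensor names (X, C, S) follow the chroma XCS convention.

```python
import torch

def expand_batch(X, C, S, samples):
    """Tile single-protein XCS tensors along the batch dimension.

    `expand` creates broadcasted views without copying, so this is cheap
    even for large `samples`. Shapes shown are illustrative:
      X: [1, N, 4, 3] backbone coordinates
      C: [1, N]       chain map
      S: [1, N]       sequence
    """
    X = X.expand(samples, -1, -1, -1)  # -> [samples, N, 4, 3]
    C = C.expand(samples, -1)          # -> [samples, N]
    S = S.expand(samples, -1)          # -> [samples, N]
    return X, C, S

# Usage with dummy tensors for a 314-residue protein:
X, C, S = torch.zeros(1, 314, 4, 3), torch.zeros(1, 314), torch.zeros(1, 314)
Xb, Cb, Sb = expand_batch(X, C, S, 4)
print(Xb.shape)  # torch.Size([4, 314, 4, 3])
```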

wujiewang commented 7 months ago

Hey @elkoz, thanks for your interest and the PR!! The fix looks reasonable to me and could be very useful. However, we have not figured out a protocol for accepting external PRs yet. We'll discuss with the team next week and get back to you!

elkoz commented 7 months ago

Note that even when the samples > 1 behavior is fixed, applying the subsequence conditioner with samples > 1 still fails. It appears to be an issue with this particular conditioner (substructure conditioning, for example, works correctly), but I'm not sure how to fix it.

Here is the code snippet:

from chroma import Chroma, Protein, conditioners
from chroma.utility.chroma import plane_split_protein
API_KEY = ...

from chroma import api
api.register_key(API_KEY)

chroma = Chroma()

device = "cuda:0"
pdb_id = "7KGK"

protein = Protein.from_PDBID(pdb_id, canonicalize=True, device=device)

X, C, _ = protein.to_XCS()
selection_string = "namesel infilling_selection"  
residues_to_design = plane_split_protein(X, C, protein, 0.5).nonzero()[:,1].tolist()
protein.sys.save_selection(gti=residues_to_design, selname="infilling_selection")

sequence_conditioner = conditioners.SubsequenceConditioner(
    design_model=chroma.design_network, protein=protein, selection=selection_string
).to(device)

infilled_proteins, trajectories = chroma.sample(
    protein_init=protein,
    conditioner=sequence_conditioner,
    langevin_factor=4.0,
    langevin_isothermal=True,
    inverse_temperature=8.0,
    steps=500,
    full_output=True,
    samples=4,
)

And here is the error trace.

{
    "name": "RuntimeError",
    "message": "Size does not match at dimension 0 expected index [4, 18840, 1] to be smaller than self [1, 314, 1] apart from dimension 1",
    "stack": "---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/home/liza/chroma/test.ipynb Cell 4 line 2
     18 protein.sys.save_selection(gti=residues_to_design, selname=\"infilling_selection\")
     20 sequence_conditioner = conditioners.SubsequenceConditioner(
     21     design_model=chroma.design_network, protein=protein, selection=selection_string
     22 ).to(device)
---> 24 infilled_proteins, trajectories = chroma.sample(
     25     protein_init=protein,
     26     conditioner=sequence_conditioner,
     27     langevin_factor=4.0,
     28     langevin_isothermal=True,
     29     inverse_temperature=8.0,
     30     steps=500,
     31     full_output=True,
     32     samples=4,
     33 )

File ~/chroma/chroma/models/chroma.py:236, in Chroma.sample(self, samples, steps, chain_lengths, tspan, protein_init, conditioner, langevin_factor, langevin_isothermal, inverse_temperature, initialize_noise, integrate_func, sde_func, trajectory_length, full_output, batch_size, design_ban_S, design_method, design_selection, design_t, temperature_S, temperature_chi, top_p_S, regularization, potts_mcmc_depth, potts_proposal, potts_symmetry_order, verbose)
    233 design_kwargs = {k: input_args[k] for k in input_args if k in design_keys}
    235 # Perform Sampling
--> 236 sample_output = self._sample(**backbone_kwargs)
    238 if full_output:
    239     protein_sample, output_dictionary = sample_output

File ~/chroma/chroma/models/chroma.py:381, in Chroma._sample(self, samples, steps, chain_lengths, tspan, protein_init, conditioner, langevin_factor, langevin_isothermal, inverse_temperature, initialize_noise, integrate_func, sde_func, trajectory_length, full_output, batch_size, **kwargs)
    373 outs = {
    374     \"C\": torch.tensor([], device=X_unc.device), 
    375     \"X_sample\": torch.tensor([], device=X_unc.device),
   (...)
    378     \"Xunc_trajectory\": [torch.tensor([], device=X_unc.device) for i in range(steps)],
    379 }
    380 for b in range(num_batches):
--> 381     outs_ = self.backbone_network.sample_sde(
    382         C_unc[b * batch_size : (b + 1) * batch_size],
    383         X_init=X_unc[b * batch_size : (b + 1) * batch_size],
    384         conditioner=conditioner,
    385         tspan=tspan,
    386         langevin_isothermal=langevin_isothermal,
    387         integrate_func=integrate_func,
    388         sde_func=sde_func,
    389         langevin_factor=langevin_factor,
    390         inverse_temperature=inverse_temperature,
    391         N=steps,
    392         initialize_noise=initialize_noise,
    393         **kwargs,
    394     )
    395     outs[\"C\"] = torch.cat([outs[\"C\"], outs_[\"C\"]], dim=0)
    396     outs[\"X_sample\"] = torch.cat([outs[\"X_sample\"], outs_[\"X_sample\"]], dim=0)

File ~/chroma/chroma/models/graph_backbone.py:187, in GraphBackbone.__init__.<locals>.<lambda>(C, **kwargs)
    185 # Wrap sampling functions
    186 _X0_func = lambda X, C, t: self.denoise(X, C, t)
--> 187 self.sample_sde = lambda C, **kwargs: self.noise_perturb.sample_sde(
    188     _X0_func, C, **kwargs
    189 )
    190 self.sample_baoab = lambda C, **kwargs: self.noise_perturb.sample_baoab(
    191     _X0_func, C, **kwargs
    192 )
    193 self.sample_ode = lambda C, **kwargs: self.noise_perturb.sample_ode(
    194     _X0_func, C, **kwargs
    195 )

File ~/miniconda3/envs/chroma/lib/python3.9/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/chroma/chroma/layers/structure/diffusion.py:1208, in DiffusionChainCov.sample_sde(self, X0_func, C, X_init, conditioner, N, tspan, inverse_temperature, langevin_factor, langevin_isothermal, sde_func, integrate_func, initialize_noise, remap_time, remove_drift_translate, remove_noise_translate, align_X0)
   1206         U_test = 0.0
   1207         t_test = torch.tensor([0.0], device=X_init.device)
-> 1208         _, Ct, _, _, _ = conditioner(X_init_test, C, O_test, U_test, t_test)
   1209 else:
   1210     Ct = C

File ~/miniconda3/envs/chroma/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/chroma/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/chroma/chroma/data/xcs.py:114, in validate_XCS.<locals>.decorator.<locals>.new_func(*args, **kwargs)
    112         if not torch.allclose(tensors[\"O\"].argmax(dim=2), tensors[\"S\"]):
    113             raise ValueError(\"S and O are both provided but don't match!\")
--> 114 return func(*args, **kwargs)

File ~/chroma/chroma/layers/structure/conditioners.py:241, in SubsequenceConditioner.forward(self, X, C, O, U, t)
    239 if self.mask_condition is not None:
    240     priority = 1.0 - self.mask_condition
--> 241 out = self.design_model(X_input, C, self.S_condition, t, priority=priority)
    242 logp_S = out[\"logp_S\"]
    244 if self.mask_condition is not None:

File ~/miniconda3/envs/chroma/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/chroma/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/chroma/chroma/data/xcs.py:114, in validate_XCS.<locals>.decorator.<locals>.new_func(*args, **kwargs)
    112         if not torch.allclose(tensors[\"O\"].argmax(dim=2), tensors[\"S\"]):
    113             raise ValueError(\"S and O are both provided but don't match!\")
--> 114 return func(*args, **kwargs)

File ~/chroma/chroma/models/graph_design.py:431, in GraphDesign.forward(self, X, C, S, t, sample_noise, permute_idx, priority)
    429 logp_S_potts = None
    430 if self.kwargs[\"predict_S_potts\"]:
--> 431     logp_S_potts = self.decoder_S_potts.loss(
    432         S, node_h, edge_h, edge_idx, mask_i, mask_ij
    433     )
    435 # Sample random permutations and build autoregressive mask
    436 if permute_idx is None:

File ~/chroma/chroma/layers/structure/potts.py:542, in GraphPotts.loss(self, S, node_h, edge_h, edge_idx, mask_i, mask_ij)
    539 h, J = self.forward(node_h, edge_h, edge_idx, mask_i, mask_ij)
    541 # Log composite likelihood
--> 542 logp_ij, mask_p_ij = self.log_composite_likelihood(
    543     S,
    544     h,
    545     J,
    546     edge_idx,
    547     mask_i,
    548     mask_ij,
    549     smoothing_alpha=self.label_smoothing if self.training else 0.0,
    550 )
    552 # Map into approximate local likelihoods
    553 logp_i = (
    554     mask_i
    555     * torch.sum(mask_p_ij * logp_ij, dim=-1)
    556     / (2.0 * torch.sum(mask_p_ij, dim=-1) + 1e-3)
    557 )

File ~/chroma/chroma/layers/structure/potts.py:458, in GraphPotts.log_composite_likelihood(self, S, h, J, edge_idx, mask_i, mask_ij, smoothing_alpha)
    454 num_batch, num_residues, num_k, num_states, _ = list(J.size())
    456 # Gather J clamped at j
    457 # [Batch,i,j,A_i,A_j] => J_ij(:,A_j) [Batch,i,j,A_i]
--> 458 S_j = graph.collect_neighbors(S.unsqueeze(-1), edge_idx)
    459 S_j = S_j.unsqueeze(-1).expand(-1, -1, -1, num_states, -1)
    460 # (B,i,j,A_i)

File ~/chroma/chroma/layers/graph.py:677, in collect_neighbors(node_h, edge_idx)
    675 idx_flat = edge_idx.reshape([num_batch, num_nodes * num_neighbors, 1])
    676 idx_flat = idx_flat.expand(-1, -1, num_features)
--> 677 neighbor_h = torch.gather(node_h, 1, idx_flat)
    678 neighbor_h = neighbor_h.reshape((num_batch, num_nodes, num_neighbors, num_features))
    679 return neighbor_h

RuntimeError: Size does not match at dimension 0 expected index [4, 18840, 1] to be smaller than self [1, 314, 1] apart from dimension 1"
}
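The final RuntimeError comes from torch.gather, which requires the index tensor to be no larger than the input on every dimension except the gather dimension. A minimal standalone reproduction of the mismatch, using the shapes from the trace (sampling batch of 4 vs. the conditioner's internal batch of 1, with 314 residues and 60 neighbors assumed to explain 18840 = 314 × 60), might look like:

```python
import torch

# The conditioner holds tensors with batch size 1...
node_h = torch.zeros(1, 314, 1)
# ...while the sampler passes a flattened neighbor index with batch size 4.
idx = torch.zeros(4, 314 * 60, 1, dtype=torch.long)

try:
    torch.gather(node_h, 1, idx)
except RuntimeError as e:
    print(e)  # size mismatch at dimension 0, as in the trace above

# Expanding the smaller tensor along the batch dimension resolves it;
# gather's output takes the shape of the index tensor.
out = torch.gather(node_h.expand(4, -1, -1), 1, idx)
print(out.shape)  # torch.Size([4, 18840, 1])
```

This suggests the conditioner's cached tensors (e.g. S_condition) would need the same batch expansion as X_unc/C_unc/S_unc, though where exactly to apply it inside SubsequenceConditioner is the open question.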
wujiewang commented 7 months ago

Yeah, not all conditioners have been tested for batched sampling. We'll need to figure this out. Thanks for flagging!