Open elkoz opened 7 months ago
Hey @elkoz, thanks for the interest and the PR! The fix looks reasonable to me and could be very useful. However, we have not yet figured out a protocol for accepting external PRs. I'll discuss with the team next week and get back to you!
Note that even with the samples > 1 behavior fixed, applying the subsequence conditioner with samples > 1 still fails. It appears to be an issue with that particular conditioner (substructure conditioning, for example, works correctly), but I'm not sure how to fix it.
Here is the code snippet:
from chroma import Chroma, Protein, conditioners
from chroma.utility.chroma import plane_split_protein
API_KEY = ...
from chroma import api
api.register_key(API_KEY)
chroma = Chroma()
device = "cuda:0"
pdb_id = "7KGK"
protein = Protein.from_PDBID(pdb_id, canonicalize=True, device=device)
X, C, _ = protein.to_XCS()
selection_string = "namesel infilling_selection"
residues_to_design = plane_split_protein(X, C, protein, 0.5).nonzero()[:,1].tolist()
protein.sys.save_selection(gti=residues_to_design, selname="infilling_selection")
sequence_conditioner = conditioners.SubsequenceConditioner(
    design_model=chroma.design_network, protein=protein, selection=selection_string
).to(device)
infilled_proteins, trajectories = chroma.sample(
    protein_init=protein,
    conditioner=sequence_conditioner,
    langevin_factor=4.0,
    langevin_isothermal=True,
    inverse_temperature=8.0,
    steps=500,
    full_output=True,
    samples=4,
)
And here is the error trace.
{
"name": "RuntimeError",
"message": "Size does not match at dimension 0 expected index [4, 18840, 1] to be smaller than self [1, 314, 1] apart from dimension 1",
"stack": "---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/home/liza/chroma/test.ipynb Cell 4 line 2
     18 protein.sys.save_selection(gti=residues_to_design, selname="infilling_selection")
     20 sequence_conditioner = conditioners.SubsequenceConditioner(
     21     design_model=chroma.design_network, protein=protein, selection=selection_string
     22 ).to(device)
---> 24 infilled_proteins, trajectories = chroma.sample(
     25     protein_init=protein,
     26     conditioner=sequence_conditioner,
     27     langevin_factor=4.0,
     28     langevin_isothermal=True,
     29     inverse_temperature=8.0,
     30     steps=500,
     31     full_output=True,
     32     samples=4,
     33 )
File ~/chroma/chroma/models/chroma.py:236, in Chroma.sample(self, samples, steps, chain_lengths, tspan, protein_init, conditioner, langevin_factor, langevin_isothermal, inverse_temperature, initialize_noise, integrate_func, sde_func, trajectory_length, full_output, batch_size, design_ban_S, design_method, design_selection, design_t, temperature_S, temperature_chi, top_p_S, regularization, potts_mcmc_depth, potts_proposal, potts_symmetry_order, verbose)
233 design_kwargs = {k: input_args[k] for k in input_args if k in design_keys}
235 # Perform Sampling
--> 236 sample_output = self._sample(**backbone_kwargs)
238 if full_output:
239 protein_sample, output_dictionary = sample_output
File ~/chroma/chroma/models/chroma.py:381, in Chroma._sample(self, samples, steps, chain_lengths, tspan, protein_init, conditioner, langevin_factor, langevin_isothermal, inverse_temperature, initialize_noise, integrate_func, sde_func, trajectory_length, full_output, batch_size, **kwargs)
373 outs = {
374     "C": torch.tensor([], device=X_unc.device),
375     "X_sample": torch.tensor([], device=X_unc.device),
(...)
378     "Xunc_trajectory": [torch.tensor([], device=X_unc.device) for i in range(steps)],
379 }
380 for b in range(num_batches):
--> 381 outs_ = self.backbone_network.sample_sde(
382 C_unc[b * batch_size : (b + 1) * batch_size],
383 X_init=X_unc[b * batch_size : (b + 1) * batch_size],
384 conditioner=conditioner,
385 tspan=tspan,
386 langevin_isothermal=langevin_isothermal,
387 integrate_func=integrate_func,
388 sde_func=sde_func,
389 langevin_factor=langevin_factor,
390 inverse_temperature=inverse_temperature,
391 N=steps,
392 initialize_noise=initialize_noise,
393 **kwargs,
394 )
395 outs["C"] = torch.cat([outs["C"], outs_["C"]], dim=0)
396 outs["X_sample"] = torch.cat([outs["X_sample"], outs_["X_sample"]], dim=0)
File ~/chroma/chroma/models/graph_backbone.py:187, in GraphBackbone.__init__.<locals>.<lambda>(C, **kwargs)
185 # Wrap sampling functions
186 _X0_func = lambda X, C, t: self.denoise(X, C, t)
--> 187 self.sample_sde = lambda C, **kwargs: self.noise_perturb.sample_sde(
188 _X0_func, C, **kwargs
189 )
190 self.sample_baoab = lambda C, **kwargs: self.noise_perturb.sample_baoab(
191 _X0_func, C, **kwargs
192 )
193 self.sample_ode = lambda C, **kwargs: self.noise_perturb.sample_ode(
194 _X0_func, C, **kwargs
195 )
File ~/miniconda3/envs/chroma/lib/python3.9/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File ~/chroma/chroma/layers/structure/diffusion.py:1208, in DiffusionChainCov.sample_sde(self, X0_func, C, X_init, conditioner, N, tspan, inverse_temperature, langevin_factor, langevin_isothermal, sde_func, integrate_func, initialize_noise, remap_time, remove_drift_translate, remove_noise_translate, align_X0)
1206 U_test = 0.0
1207 t_test = torch.tensor([0.0], device=X_init.device)
-> 1208 _, Ct, _, _, _ = conditioner(X_init_test, C, O_test, U_test, t_test)
1209 else:
1210 Ct = C
File ~/miniconda3/envs/chroma/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/chroma/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File ~/chroma/chroma/data/xcs.py:114, in validate_XCS.<locals>.decorator.<locals>.new_func(*args, **kwargs)
112 if not torch.allclose(tensors["O"].argmax(dim=2), tensors["S"]):
113     raise ValueError("S and O are both provided but don't match!")
--> 114 return func(*args, **kwargs)
File ~/chroma/chroma/layers/structure/conditioners.py:241, in SubsequenceConditioner.forward(self, X, C, O, U, t)
239 if self.mask_condition is not None:
240 priority = 1.0 - self.mask_condition
--> 241 out = self.design_model(X_input, C, self.S_condition, t, priority=priority)
242 logp_S = out[\"logp_S\"]
244 if self.mask_condition is not None:
File ~/miniconda3/envs/chroma/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/chroma/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File ~/chroma/chroma/data/xcs.py:114, in validate_XCS.<locals>.decorator.<locals>.new_func(*args, **kwargs)
112 if not torch.allclose(tensors["O"].argmax(dim=2), tensors["S"]):
113     raise ValueError("S and O are both provided but don't match!")
--> 114 return func(*args, **kwargs)
File ~/chroma/chroma/models/graph_design.py:431, in GraphDesign.forward(self, X, C, S, t, sample_noise, permute_idx, priority)
429 logp_S_potts = None
430 if self.kwargs[\"predict_S_potts\"]:
--> 431 logp_S_potts = self.decoder_S_potts.loss(
432 S, node_h, edge_h, edge_idx, mask_i, mask_ij
433 )
435 # Sample random permutations and build autoregressive mask
436 if permute_idx is None:
File ~/chroma/chroma/layers/structure/potts.py:542, in GraphPotts.loss(self, S, node_h, edge_h, edge_idx, mask_i, mask_ij)
539 h, J = self.forward(node_h, edge_h, edge_idx, mask_i, mask_ij)
541 # Log composite likelihood
--> 542 logp_ij, mask_p_ij = self.log_composite_likelihood(
543 S,
544 h,
545 J,
546 edge_idx,
547 mask_i,
548 mask_ij,
549 smoothing_alpha=self.label_smoothing if self.training else 0.0,
550 )
552 # Map into approximate local likelihoods
553 logp_i = (
554 mask_i
555 * torch.sum(mask_p_ij * logp_ij, dim=-1)
556 / (2.0 * torch.sum(mask_p_ij, dim=-1) + 1e-3)
557 )
File ~/chroma/chroma/layers/structure/potts.py:458, in GraphPotts.log_composite_likelihood(self, S, h, J, edge_idx, mask_i, mask_ij, smoothing_alpha)
454 num_batch, num_residues, num_k, num_states, _ = list(J.size())
456 # Gather J clamped at j
457 # [Batch,i,j,A_i,A_j] => J_ij(:,A_j) [Batch,i,j,A_i]
--> 458 S_j = graph.collect_neighbors(S.unsqueeze(-1), edge_idx)
459 S_j = S_j.unsqueeze(-1).expand(-1, -1, -1, num_states, -1)
460 # (B,i,j,A_i)
File ~/chroma/chroma/layers/graph.py:677, in collect_neighbors(node_h, edge_idx)
675 idx_flat = edge_idx.reshape([num_batch, num_nodes * num_neighbors, 1])
676 idx_flat = idx_flat.expand(-1, -1, num_features)
--> 677 neighbor_h = torch.gather(node_h, 1, idx_flat)
678 neighbor_h = neighbor_h.reshape((num_batch, num_nodes, num_neighbors, num_features))
679 return neighbor_h
RuntimeError: Size does not match at dimension 0 expected index [4, 18840, 1] to be smaller than self [1, 314, 1] apart from dimension 1"
}
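The shape mismatch at the bottom of the trace can be reproduced in isolation. torch.gather requires the index tensor to be no larger than the input in every dimension except the gather dimension, so a conditioner tensor cached with batch size 1 cannot be gathered with indices batched for samples=4. A minimal sketch (shapes chosen to mirror the trace: 314 residues, 60 neighbors, 4 samples; these numbers are illustrative, not taken from the chroma internals):

```python
import torch

# Cached per-protein sequence with batch size 1, as the conditioner holds it.
S = torch.zeros(1, 314, dtype=torch.long)
# Neighbor indices batched for samples=4 (314 residues x 60 neighbors = 18840).
edge_idx = torch.zeros(4, 314, 60, dtype=torch.long)

num_batch, num_nodes, num_neighbors = edge_idx.shape
idx_flat = edge_idx.reshape(num_batch, num_nodes * num_neighbors, 1)

# Batch dims disagree (4 vs 1), so gather raises the RuntimeError from the trace.
try:
    torch.gather(S.unsqueeze(-1), 1, idx_flat)
    failed = False
except RuntimeError:
    failed = True

# Expanding the cached tensor along the batch dimension resolves the mismatch.
S_batched = S.expand(num_batch, -1)
out = torch.gather(S_batched.unsqueeze(-1), 1, idx_flat)  # shape [4, 18840, 1]
```

This suggests the conditioner's cached tensors (here stand-ins for S_condition and mask_condition) would need to be expanded to the sampling batch size, in the same spirit as the fix in this PR.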
Yeah, not all conditioners have been tested for batched sampling. We'll need to figure this out. Thanks for flagging!
This looks like a really great project, thank you for making it!
I encountered some inconveniences when trying to generate a large number of samples with protein_init not set to None, and I thought my fixes might be useful for other people.

First, there is currently a bug that causes only one sample to be generated when protein_init is not None, regardless of the samples parameter. Here I expand X_unc, C_unc and S_unc along the batch dimension before passing them to the model to fix that.

Second, since there is no batching during sampling, generating a large number of samples is inconvenient. In this PR I added a batch_size parameter to make it possible to generate a larger number of outputs with one command.
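The batch-expansion part of the fix can be sketched as follows. This is a minimal illustration, not the PR's actual diff: the shapes are hypothetical (314 residues, 4 backbone atoms), and the variable names X_unc, C_unc, S_unc are taken from the trace above.

```python
import torch

samples = 4
# Hypothetical single-protein tensors, shaped like the output of protein.to_XCS():
# X: [1, num_residues, 4, 3] coordinates, C: [1, num_residues] chain map,
# S: [1, num_residues] sequence.
X_unc = torch.randn(1, 314, 4, 3)
C_unc = torch.ones(1, 314, dtype=torch.long)
S_unc = torch.zeros(1, 314, dtype=torch.long)

# Tile each tensor along the batch dimension so every sample starts from its
# own copy of the initial protein before the tensors reach the sampler.
X_unc = X_unc.expand(samples, -1, -1, -1).contiguous()
C_unc = C_unc.expand(samples, -1).contiguous()
S_unc = S_unc.expand(samples, -1).contiguous()
```

With per-sample batches in place, a batch_size parameter then only needs to slice these tensors into chunks of at most batch_size along dimension 0, as the num_batches loop in Chroma._sample shown in the trace does.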