generatebio / chroma

A generative model for programmable protein design
Apache License 2.0

How to do deterministic sampling? #42

Closed: Lifedecoder closed this issue 9 months ago

Lifedecoder commented 10 months ago

Hi! Thanks for the great work!

I have a question about how to do deterministic sampling. The inference results are not the same when I run the same code multiple times (for example, when I mask one portion of the residues and redesign them). Even after changing autoregressive sampling to choose the maximum-probability token at each node, the results still differ. I found that the chains are permuted during preprocessing, which might be the reason. For Potts sampling, the results also differ between runs.

How can I sample deterministically, choosing the result with maximum probability (that is, the most likely chain)?

aismail3-gnr8 commented 10 months ago

There seem to be two separate tasks here: making sampling deterministic, and choosing the maximum-probability residue during sequence design.

For the first, we've found that PyTorch 2 has deterministic versions of all the GPU operations we use. So, if you're willing to upgrade, you could try a snippet like this.

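import os
import warnings

import torch
from chroma import Chroma

# cuBLAS needs a fixed workspace size for reproducible results;
# set this before any CUDA/cuBLAS work happens.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# Use deterministic kernels; ops without one emit a warning instead of raising.
torch.use_deterministic_algorithms(True, warn_only=True)
chroma = Chroma()  # should go before the warning filter
warnings.filterwarnings(action='default', message=r".*deterministic.*")

# set seed
seed = 0
torch.manual_seed(seed)
# sample from Chroma here
...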

Independent of this, if you want to choose the most likely residue regardless, you could try lowering the temperature of the sequence sampling process via the temperature_S argument of Chroma.sample. Hope this helps!
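As a generic illustration (not Chroma's code), dividing the logits by a temperature before the softmax sharpens the distribution: as the temperature approaches zero, nearly all of the probability lands on the highest-scoring residue, so sampling approaches picking the argmax. The logits below are made up, and this only covers the per-position choice, not other sources of randomness such as the decoding order.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Return softmax(logits / temperature) as a list of probabilities."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for four candidate residues at one position.
logits = [2.0, 1.5, 0.3, -1.0]

for t in (1.0, 0.1, 1e-6):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 4) for p in probs])
```

At a temperature of 1e-6 the distribution is effectively one-hot on the first residue, which is the regime a very low temperature_S is meant to approach.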

Lifedecoder commented 10 months ago

Thanks for the quick response!

I think my question is closer to the latter: I want to generate the most likely result among the possible outputs.

Sorry my earlier description was unclear. I was thinking in terms of autoregressive methods, where you choose the maximum-probability token at each step and get the (greedily) most likely residue chain; that way the result is both deterministic and the one you want. This kind of "deterministic" does not depend on the random seed, but on how you sample from each node's predicted distribution.

So the question becomes: how do I choose the maximum-probability token at each node so that autoregressive/Potts sampling produces the most likely result? I modified the code at the line linked below to pick the maximum probability at each sampling step, but the results still differ across runs.

https://github.com/generatebio/chroma/blob/929407c605013613941803c6113adefdccaad679/chroma/models/graph_design.py#L1978

Could this be due to the randomness of the permutation?

I also tried lowering the temperature to 1e-6 and 1e-9, but the results still differ across runs (shouldn't the best result stay the same?).

Or is there another way to sample the best result?

Thanks!!

aismail3-gnr8 commented 10 months ago

Yes, the permutation used for autoregressive sequence sampling is generated randomly. You can fix the random seed for this permutation by setting chroma.design_network.traversal.deterministic = True, or alternatively by calling torch.manual_seed before generation. The code above uses the latter, so everything ends up deterministic, not just the sequence design.
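The effect of fixing the seed on a random traversal order can be sketched with Python's standard random module. This stands in for whatever RNG Chroma uses internally (presumably PyTorch's, which torch.manual_seed controls); decoding_order is a hypothetical helper, not part of Chroma.

```python
import random


def decoding_order(num_residues, seed=None):
    """Return a random permutation of residue indices (hypothetical helper).

    With seed=None the order changes from run to run; with a fixed seed
    the same order is produced every time.
    """
    rng = random.Random(seed)
    order = list(range(num_residues))
    rng.shuffle(order)
    return order


print(decoding_order(8, seed=0))
print(decoding_order(8, seed=0))  # identical to the line above
```

Seeding a dedicated random.Random instance, rather than the global module state, keeps the permutation reproducible even if other code draws random numbers in between.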

Just to clarify: even when performing autoregressive sampling, Chroma lets you sample at an arbitrary temperature rather than always choosing the most likely residue given those already selected. Choosing a low temperature_S together with a fixed permutation should give you a deterministic selection of the best residues in a set order.
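A toy two-position model (made-up scores, not Chroma's Potts or autoregressive decoder) shows why a near-zero temperature alone is not enough: greedy decoding is deterministic only for a given order, and different orders can settle on different sequences. ALPHABET, UNARY, coupling, and greedy_decode are all hypothetical names for this illustration.

```python
ALPHABET = "AB"
# Hypothetical per-position scores (higher is better).
UNARY = [{"A": 1.0, "B": 0.0}, {"A": 0.0, "B": 0.5}]


def coupling(a, b):
    """Hypothetical pairwise score between positions 0 and 1: reward matching residues."""
    return 2.0 if a == b else 0.0


def greedy_decode(order):
    """Greedily pick the best residue at each position, visiting positions in `order`.

    Each choice only sees residues that were already decoded, as in
    autoregressive design; this is the temperature -> 0 limit of sampling.
    """
    seq = [None, None]
    for pos in order:
        other = 1 - pos

        def score(res):
            s = UNARY[pos][res]
            if seq[other] is not None:
                s += coupling(res, seq[other])
            return s

        seq[pos] = max(ALPHABET, key=score)
    return "".join(seq)


print(greedy_decode([0, 1]))  # decode position 0 first -> AA
print(greedy_decode([1, 0]))  # decode position 1 first -> BB
```

Each order always gives the same answer, but the two orders disagree, which matches the behavior described in this thread: the result only becomes reproducible once the permutation is fixed as well.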

Lifedecoder commented 9 months ago

Thanks again for your response! I tried this recently and found that the permutation is what causes the randomness. My main change, at the lines linked below, keeps the masked residues in their original sequence order using the prioritized mask (in the original code they are randomly permuted and placed last in the generation order).

https://github.com/generatebio/chroma/blob/929407c605013613941803c6113adefdccaad679/chroma/models/graph_design.py#L2267C9-L2271C47
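The change described above can be sketched generically: put the fixed (unmasked) residues first and the masked residues last, then either shuffle the masked ones (random order, as in the original behavior described here) or keep them in sequence order (deterministic). build_decoding_order and its arguments are hypothetical, not Chroma's API, and keeping the unmasked residues in index order is a simplifying assumption; the linked lines in graph_design.py are where the actual change was made.

```python
import random


def build_decoding_order(mask, rng=None):
    """Return the order in which to visit residue indices (hypothetical helper).

    mask[i] is True for residues to redesign. Unmasked residues come first,
    masked residues last. If rng is given, the masked residues are shuffled;
    otherwise they keep their sequence order, so the result is deterministic.
    """
    fixed = [i for i, m in enumerate(mask) if not m]
    designed = [i for i, m in enumerate(mask) if m]
    if rng is not None:
        rng.shuffle(designed)  # random order for the residues being redesigned
    return fixed + designed


mask = [False, True, False, True, True]
print(build_decoding_order(mask))                 # deterministic: [0, 2, 1, 3, 4]
print(build_decoding_order(mask, random.Random()))  # masked part may vary per run
```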