NVIDIA / bionemo-framework

BioNeMo Framework: For building and adapting AI models in drug discovery at scale
https://nvidia.github.io/bionemo-framework/

Enable extraction of gene embeddings from geneformer (averaging of gene embeddings across all cells) #452

Open jstjohn opened 20 hours ago

jstjohn commented 20 hours ago

A potential design:

  1. Add an argparse option --num-layers-override, with a default of None, to infer_geneformer.py: https://github.com/NVIDIA/bionemo-framework/blob/main/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/infer_geneformer.py#L235
  2. Add logic at https://github.com/NVIDIA/bionemo-framework/blob/main/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/infer_geneformer.py#L34 so that if the override is unset, behavior is unchanged.
  3. If the override is set, we need to do the following for it to take effect in the model:
    1. Import OVERRIDE_BIOBERT_CONFIG_DEFAULTS from https://github.com/NVIDIA/bionemo-framework/blob/main/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py#L105
    2. Pass override_parent_fields=['num_layers'] + OVERRIDE_BIOBERT_CONFIG_DEFAULTS to the config_class (around https://github.com/NVIDIA/bionemo-framework/blob/main/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/infer_geneformer.py#L116), but only if the user set num_layers_override (i.e., it is not None). This tells the checkpoint loader not to read this field from the trained model config stored in the checkpoint, and to use the user-supplied value instead.
    3. Also pass num_layers=num_layers_override to the config at that point, again only if it is not None.
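A minimal sketch of steps 1 to 3, assuming the config class accepts override_parent_fields and num_layers as keyword arguments, as the steps describe. The helper names parse_args and num_layers_config_kwargs are hypothetical, and OVERRIDE_BIOBERT_CONFIG_DEFAULTS is an empty placeholder here so the sketch runs standalone; in the real script it would be imported from the bionemo-llm module linked in step 3.1.

```python
import argparse
from typing import Any, Dict, List, Optional

# Placeholder so this sketch runs standalone; the real script would import
# OVERRIDE_BIOBERT_CONFIG_DEFAULTS from bionemo.llm.model.biobert.model instead.
OVERRIDE_BIOBERT_CONFIG_DEFAULTS: List[str] = []


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Step 1: expose --num-layers-override, defaulting to None (no override)."""
    parser = argparse.ArgumentParser(description="Geneformer inference (sketch)")
    parser.add_argument(
        "--num-layers-override",
        type=int,
        default=None,
        help="Build the model with this many layers instead of the checkpoint's num_layers.",
    )
    return parser.parse_args(argv)


def num_layers_config_kwargs(num_layers_override: Optional[int]) -> Dict[str, Any]:
    """Steps 2 and 3: extra keyword arguments for the config class (hypothetical helper)."""
    if num_layers_override is None:
        return {}  # Step 2: override unset, so the config is built exactly as before.
    return {
        # Step 3.2: tell the checkpoint loader not to take num_layers from the checkpoint.
        "override_parent_fields": ["num_layers"] + OVERRIDE_BIOBERT_CONFIG_DEFAULTS,
        # Step 3.3: use the user-supplied layer count instead.
        "num_layers": num_layers_override,
    }
```

In the script, the returned dict would be splatted into the existing config-class call, along the lines of config_class(<existing arguments>, **num_layers_config_kwargs(args.num_layers_override)). Returning an empty dict when the flag is unset keeps the default code path untouched.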

The model will then be initialized with the user-requested number of layers rather than the num_layers it was originally trained with. So if you want to remove the last layer and get inference results from the second-to-last layer, and you know the model was trained with 6 layers, you could set --num-layers-override 5 and get back a 5-layer model with the last layer left off.
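To see why dropping the last layer yields the second-to-last layer's output, here is a toy sketch in plain Python. It assumes layer weights are loaded by index, so a 5-layer model picks up layers 0 to 4 of a 6-layer checkpoint; the "layers" are simple affine maps standing in for transformer blocks, and all names are hypothetical.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
Layer = Callable[[Vector], Vector]


def make_layer(scale: float, shift: float) -> Layer:
    """Toy stand-in for a transformer block: an elementwise affine map."""
    return lambda h: [scale * x + shift for x in h]


def build_model(checkpoint: Sequence[Tuple[float, float]], num_layers: int) -> List[Layer]:
    """Instantiate num_layers layers, loading layer i's parameters from checkpoint[i]."""
    if num_layers > len(checkpoint):
        raise ValueError("checkpoint has fewer layers than requested")
    return [make_layer(*checkpoint[i]) for i in range(num_layers)]


def forward_all(layers: List[Layer], h: Vector) -> List[Vector]:
    """Run the stack and return the hidden state after every layer."""
    hiddens = []
    for layer in layers:
        h = layer(h)
        hiddens.append(h)
    return hiddens


# A "trained" 6-layer checkpoint with toy parameters.
checkpoint = [(1.5, 0.1), (0.5, -0.2), (2.0, 0.0), (1.0, 0.3), (0.8, 0.05), (3.0, -1.0)]
x = [0.2, -0.4, 1.0]

full = forward_all(build_model(checkpoint, 6), x)
truncated = forward_all(build_model(checkpoint, 5), x)
# The 5-layer model's final output is the 6-layer model's second-to-last hidden state.
assert truncated[-1] == full[-2]
```

The toy ignores any normalization a real stack may apply after whichever block is last, so in practice the truncated model's output may not be bit-identical to the full model's raw second-to-last hidden state.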

Side note: These steps are generally how you would override any setting in the loaded model. The same pattern works for fine-tuning as well as inference whenever you want to change something about the model at load time. Note that in the fine-tuning case (not here), if you add a new layer, you also need to tell the checkpoint loader not to look for that layer in the checkpoint; otherwise you get a confusing error about the layer not being found at checkpoint load time.
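A toy illustration of the fine-tuning caveat, assuming a loader that copies parameters from the checkpoint by name. The function load_matching_weights and its skip_prefixes parameter are hypothetical; they only show why a newly added layer has to be excluded from the checkpoint lookup.

```python
from typing import Dict, Iterable, List


def load_matching_weights(
    model_keys: Iterable[str],
    checkpoint: Dict[str, List[float]],
    skip_prefixes: Iterable[str] = (),
) -> Dict[str, List[float]]:
    """Copy each model parameter from the checkpoint by name.

    Keys starting with any of skip_prefixes (e.g. a newly added head) are not
    looked up; they are left for fresh initialization elsewhere.
    """
    prefixes = tuple(skip_prefixes)
    loaded: Dict[str, List[float]] = {}
    for key in model_keys:
        if prefixes and key.startswith(prefixes):
            continue  # New layer: not expected to exist in the checkpoint.
        if key not in checkpoint:
            # The "layer not found" failure you hit if the new layer is not skipped.
            raise KeyError(f"{key!r} not found in checkpoint")
        loaded[key] = checkpoint[key]
    return loaded


ckpt = {"encoder.layer0.weight": [0.1, 0.2], "encoder.layer1.weight": [0.3, 0.4]}
keys = list(ckpt) + ["new_head.weight"]
loaded = load_matching_weights(keys, ckpt, skip_prefixes=["new_head."])
```

Calling load_matching_weights(keys, ckpt) without skip_prefixes raises KeyError for "new_head.weight", which mirrors the confusing load-time error described above.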

skothenhill-nv commented 19 hours ago

More context: this is about getting 'gene embeddings' from Geneformer. Right now we can pull the hidden states from the last layer, but we will need to be able to pull them from an arbitrary layer:

Our inference code:

https://github.com/NVIDIA/bionemo-framework/blob/main/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/infer_geneformer.py#L38

Description of the problem:

For each single cell transcriptome presented to Geneformer, the model embeds each gene into a 256-dimensional space that encodes the gene’s characteristics specific to the context of that cell. Contextual Geneformer gene embeddings are extracted as the hidden state weights for the 256 embedding dimensions for each gene within the given single cell transcriptome evaluated by forward pass through the Geneformer model. Gene embeddings analyzed in this study were extracted from the second to last layer of the models as the final layer is known to encompass features more directly related to the learning objective prediction while the second to last layer is a more generalizable representation.

(The second-to-last layer is handled by @jstjohn's description above.)

Reference Geneformer Hugging Face code:

https://geneformer.readthedocs.io/en/latest/_modules/geneformer/emb_extractor.html#EmbExtractor

We will need to be able to do this kind of aggregation in a memory-efficient way, and to ensure we have access to the cell labels from sc-memmap (if we want to aggregate by cell label).
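One way to keep the aggregation memory-efficient is to stream: keep a running sum and count per key and discard each batch of hidden states after folding it in, so memory grows with the number of genes (or labels) rather than the number of cells. A minimal plain-Python sketch; the class name RunningMean is hypothetical, and reading batches from inference or labels from sc-memmap is not shown.

```python
from collections import defaultdict
from typing import Dict, Hashable, List, Sequence


class RunningMean:
    """Streaming per-key mean; memory is O(#keys x hidden_dim), independent of #cells.

    Keys can be gene token ids (gene embeddings) or cell labels (per-label aggregation).
    """

    def __init__(self, hidden_dim: int) -> None:
        self.hidden_dim = hidden_dim
        self.sums: Dict[Hashable, List[float]] = defaultdict(lambda: [0.0] * hidden_dim)
        self.counts: Dict[Hashable, int] = defaultdict(int)

    def update(self, key: Hashable, vector: Sequence[float]) -> None:
        """Fold one embedding into the running sum for its key."""
        if len(vector) != self.hidden_dim:
            raise ValueError("vector has the wrong dimension")
        acc = self.sums[key]
        for i, x in enumerate(vector):
            acc[i] += x
        self.counts[key] += 1

    def means(self) -> Dict[Hashable, List[float]]:
        """Return the mean embedding seen so far for every key."""
        return {k: [s / self.counts[k] for s in acc] for k, acc in self.sums.items()}


# Usage with toy vectors; in practice each batch would be discarded after update().
acc = RunningMean(hidden_dim=2)
for key, vec in [("geneA", [1.0, 2.0]), ("geneA", [3.0, 4.0]), ("geneB", [0.0, 1.0])]:
    acc.update(key, vec)
print(acc.means())  # {'geneA': [2.0, 3.0], 'geneB': [0.0, 1.0]}
```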

isabel-wilkinson commented 5 hours ago

Context from Birkan Gökbağ

Geneformer’s embedding extraction relies on the input dataset: for every cell, we obtain the generated embeddings of that cell’s expressed genes.

Gene embedding extraction: Gene embeddings are obtained by averaging each gene’s embeddings across all cells. Since the architecture is NLP-based, the ordering of the gene tokens influences the gene embeddings, so a gene’s embedding in one cell may differ from its embedding in another. The output position of each gene token holds that gene’s embedding. Make sure to index with the CLS token position in mind (usually index 0, so you may need to shift by 1). In short: average gene embeddings across all cells.
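A small sketch of that indexing, assuming each cell comes with its ranked gene token ids (CLS excluded) and the hidden states of the chosen layer with CLS at position 0, so gene j sits at output position j + 1. The function and parameter names are hypothetical; a real pipeline would stream batches rather than hold all cells in memory.

```python
from typing import Dict, List, Sequence


def average_gene_embeddings(
    cells_gene_ids: Sequence[Sequence[int]],
    cells_hidden: Sequence[Sequence[Sequence[float]]],
    cls_offset: int = 1,
) -> Dict[int, List[float]]:
    """Average each gene's contextual embedding across all cells it appears in.

    cells_gene_ids[c][j] is the j-th gene token of cell c (CLS excluded);
    cells_hidden[c][p] is the hidden state at output position p (CLS included),
    so gene j's embedding is at position j + cls_offset.
    """
    sums: Dict[int, List[float]] = {}
    counts: Dict[int, int] = {}
    for gene_ids, hidden in zip(cells_gene_ids, cells_hidden):
        for j, gene in enumerate(gene_ids):
            vec = hidden[j + cls_offset]  # Shift past the CLS token.
            acc = sums.setdefault(gene, [0.0] * len(vec))
            for i, x in enumerate(vec):
                acc[i] += x
            counts[gene] = counts.get(gene, 0) + 1
    return {gene: [s / counts[gene] for s in acc] for gene, acc in sums.items()}


# Two toy cells; position 0 of each hidden-state list is the CLS token.
cells_gene_ids = [[10, 20], [20]]
cells_hidden = [
    [[9.0, 9.0], [1.0, 1.0], [2.0, 2.0]],
    [[9.0, 9.0], [4.0, 4.0]],
]
print(average_gene_embeddings(cells_gene_ids, cells_hidden))
# {10: [1.0, 1.0], 20: [3.0, 3.0]}
```

Gene 20 appears at different positions in the two cells, which is exactly why its embeddings differ per cell and are averaged.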

Cell embedding extraction: Since the input is a cell (i.e., a sorted series of tokens), the output is already a representation of that cell. The token embeddings are averaged, excluding the CLS token embedding, to form the cell embedding. In short: average the embeddings of the input cell directly.
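A minimal sketch of the cell embedding, assuming one cell's hidden states come with an attention mask marking real tokens (1) versus padding (0), and that CLS is at position 0. The mask and the function name cell_embedding are assumptions, not taken from the issue.

```python
from typing import List, Sequence


def cell_embedding(
    hidden: Sequence[Sequence[float]],
    attention_mask: Sequence[int],
    cls_position: int = 0,
) -> List[float]:
    """Mean of one cell's token embeddings, excluding the CLS token and padding."""
    kept = [
        h
        for pos, (h, m) in enumerate(zip(hidden, attention_mask))
        if m and pos != cls_position
    ]
    if not kept:
        raise ValueError("cell has no gene tokens to average")
    dim = len(kept[0])
    return [sum(v[i] for v in kept) / len(kept) for i in range(dim)]


# CLS at position 0, two gene tokens, one padded position.
hidden = [[100.0, 100.0], [1.0, 2.0], [3.0, 4.0], [50.0, 50.0]]
mask = [1, 1, 1, 0]
print(cell_embedding(hidden, mask))  # [2.0, 3.0]
```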

Optional aggregation by cell annotation: The previous analyses are applied separately for each cell type annotation. Because each embedding run is limited to one annotation subset, the resulting embeddings represent only that state. Those embeddings are then aggregated with the mean or median to represent the state; in effect, you take the mean (or median) of the means.
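A sketch of the "mean, or median, of the means" step, assuming per-cell embeddings (each already a mean over that cell's tokens) and a parallel list of cell-type labels, such as ones read from sc-memmap. The function name aggregate_by_label is hypothetical; the reduction is applied elementwise per embedding dimension.

```python
import statistics
from collections import defaultdict
from typing import Dict, Hashable, List, Sequence


def aggregate_by_label(
    cell_embeddings: Sequence[Sequence[float]],
    labels: Sequence[Hashable],
    method: str = "mean",
) -> Dict[Hashable, List[float]]:
    """Group cell embeddings by label and reduce each dimension with mean or median."""
    if len(cell_embeddings) != len(labels):
        raise ValueError("need exactly one label per cell")
    reducers = {"mean": statistics.fmean, "median": statistics.median}
    if method not in reducers:
        raise ValueError(f"method must be one of {sorted(reducers)}")
    reduce = reducers[method]
    groups: Dict[Hashable, List[Sequence[float]]] = defaultdict(list)
    for emb, label in zip(cell_embeddings, labels):
        groups[label].append(emb)
    # zip(*embs) walks the embeddings column by column, i.e. one dimension at a time.
    return {label: [reduce(col) for col in zip(*embs)] for label, embs in groups.items()}


cells = [[1.0, 2.0], [3.0, 4.0], [5.0, 9.0], [10.0, 0.0]]
labels = ["T cell", "T cell", "T cell", "B cell"]
print(aggregate_by_label(cells, labels, "mean"))    # {'T cell': [3.0, 5.0], 'B cell': [10.0, 0.0]}
print(aggregate_by_label(cells, labels, "median"))  # {'T cell': [3.0, 4.0], 'B cell': [10.0, 0.0]}
```

The median is less sensitive than the mean to a few outlier cells within a label, which is one reason to offer both.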

isabel-wilkinson commented 5 hours ago

Ideally a test would be added as well