kuleshov-group / caduceus

Bi-Directional Equivariant Long-Range DNA Sequence Modeling

Inference? #49

Open leannmlindsey opened 1 month ago

leannmlindsey commented 1 month ago

Yair,

Have you released any information about how to use the fine-tuned models for inference?

yair-schiff commented 1 month ago

Hi @leannmlindsey, can you please clarify the question? I am not sure what you mean by how to use the models for inference. I guess it would depend on the task. There is code in this repo already for loading a pre-trained backbone and either fine-tuning a prediction head on top of embeddings or dumping the embeddings and training a separate model on top of those frozen embeddings.
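
For the frozen-embedding route, something along these lines should work. This is an untested sketch: the Hub checkpoint name below is just an example, and the exact output attribute may differ depending on which model class you load.

```python
# Untested sketch: load a pre-trained Caduceus backbone from the HuggingFace Hub
# and extract per-token embeddings. The checkpoint name below is just an example.
import torch
from transformers import AutoModel, AutoTokenizer

model_name = "kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True).eval()

input_ids = tokenizer("ACGTACGTACGT", return_tensors="pt")["input_ids"]
with torch.no_grad():
    # Per-token hidden states; shape (batch, seq_len, d_model).
    embeddings = model(input_ids).last_hidden_state
```

A lightweight prediction head, or a separate classifier, can then be trained on top of these embeddings.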

leannmlindsey commented 6 days ago

I have several models that I have pre-trained using your pretraining instructions and then fine-tuned using the genomic benchmark code, replacing the datasets with my own.

Unfortunately, I have been unable to find a way to properly load the checkpoint and configuration from the saved output files when trying to load trained Caduceus models for inference.

To be clear, I am trying to load a model from the checkpoint/last.ckpt and model_config.json files that are saved in the output directory for a model fine-tuned with the genomic benchmark code. Here is a list of the issues I have encountered:

  1. Architecture Mismatch:

    • The checkpoint contains weights for ~64 layers
    • The model name suggests a 4-layer architecture ("4L" in caduceus_ps_char_4k_d256_4L)
    • This fundamental mismatch prevents proper weight loading
  2. Configuration File Ambiguity:

    • The repository contains both config.json and model_config.json without documentation of their purposes
    • Neither appears to match the architecture of the saved checkpoint
  3. State Dictionary Key Mismatches:

    • Checkpoint keys include: "model.caduceus.backbone.layers.X.mixer.submodule.mamba_fwd"
    • Model expects: "caduceus.backbone.layers.X.mixer.mamba_fwd"
    • These naming inconsistencies prevent automatic weight loading (see the loading sketch after this list)
  4. Complex Infrastructure Requirements:

    • Dataset configuration must be handled even when loading a model only for inference
    • Understanding the PyTorch Lightning setup is required
    • There is no clear documentation for inference-only usage
    • Simple inference tasks depend on the training infrastructure
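
For concreteness, here is a rough sketch of the kind of loading I have been attempting. The import paths, class names, output paths, and key remapping are my best guesses from the repo layout and the mismatches above, so they may well be part of what I am getting wrong:

```python
# Rough sketch of my loading attempt; paths, class names, and the key
# remapping are guesses based on the repo layout and the mismatches above.
import json
import torch

from caduceus.configuration_caduceus import CaduceusConfig  # assumed import path
from caduceus.modeling_caduceus import Caduceus             # backbone-only class

# Rebuild the backbone from the saved model_config.json.
with open("<output_dir>/model_config.json") as f:
    config = CaduceusConfig(**json.load(f))
model = Caduceus(config)

# Load the Lightning checkpoint and remap its keys: strip the "model.caduceus."
# prefix added by the Lightning wrapper and drop the ".submodule" segment
# inside each mixer (see item 3 above).
ckpt = torch.load("<output_dir>/checkpoint/last.ckpt", map_location="cpu")
remapped = {}
for key, value in ckpt["state_dict"].items():
    if key.startswith("model.caduceus."):
        key = key[len("model.caduceus."):]
    key = key.replace(".mixer.submodule.", ".mixer.")
    remapped[key] = value

# strict=False so that prediction-head weights do not block the backbone load.
missing, unexpected = model.load_state_dict(remapped, strict=False)
print(f"missing: {len(missing)}, unexpected: {len(unexpected)}")
model.eval()
```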

Feature Request: For better usability, please consider implementing a simple interface for loading trained models:

```python
from caduceus import load_pretrained_model

model = load_pretrained_model(checkpoint_path)
model.eval()
predictions = model(input_sequence)
```

I appreciate the pre-training and fine-tuning code that you have released; I would just really like to be able to use the trained models for inference and have been unable to do so thus far.

In the meantime, I will try your method of dumping the embeddings and training a separate model on top of the frozen embeddings.
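
Concretely, I am planning something like the following: dump embeddings from the backbone to disk, mean-pool them, and fit a simple scikit-learn classifier on top. The file paths below are placeholders for wherever the dumped embeddings end up:

```python
# Rough plan for the frozen-embedding route. The .npy paths are placeholders
# for wherever the dumped embeddings and labels end up.
import numpy as np
from sklearn.linear_model import LogisticRegression

train_emb = np.load("embeddings/train_embeddings.npy")  # (N, seq_len, d_model)
train_labels = np.load("embeddings/train_labels.npy")   # (N,)
test_emb = np.load("embeddings/test_embeddings.npy")
test_labels = np.load("embeddings/test_labels.npy")

# Mean-pool over the sequence dimension to get one vector per sequence.
X_train = train_emb.mean(axis=1)
X_test = test_emb.mean(axis=1)

clf = LogisticRegression(max_iter=1000).fit(X_train, train_labels)
print("test accuracy:", clf.score(X_test, test_labels))
```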