bowang-lab / scGPT

https://scgpt.readthedocs.io/en/latest/
MIT License

Issues with tokenization in embedding generation script #268

Closed xqiu625 closed 1 week ago

xqiu625 commented 1 month ago

I'm trying to generate embeddings from scGPT for my single-cell data but encountering tokenization issues. Here's my scenario and the errors I'm facing:

  1. Initial Setup:

    import json
    from pathlib import Path

    import anndata as ad
    import numpy as np
    import torch
    from scgpt.tokenizer import tokenize_and_pad_batch

    adata = ad.read_h5ad("single_cell_human_pbmc_counts.h5ad")
    model_dir = Path("scGPT_human/")
    vocab_file = model_dir / "vocab.json"
    model_config_file = model_dir / "args.json"
    model_file = model_dir / "best_model.pt"
  2. I've tried different approaches to handle the vocabulary:

Approach 1: Using the raw JSON dictionary

with open(vocab_file) as f:
    vocab = json.load(f)

Approach 2: Using a custom vocabulary class

class CustomVocab:
    def __init__(self):
        self.stoi = {}  # token -> index
        self.itos = []  # index -> token

    @classmethod
    def from_file(cls, vocab_file):
        vocab = cls()
        with open(vocab_file, "r") as f:
            token2idx = json.load(f)
            # If the JSON maps index -> token, invert it to token -> index.
            if isinstance(next(iter(token2idx.values())), str):
                token2idx = {v: int(k) for k, v in token2idx.items()}
        vocab.stoi = token2idx
        vocab.itos = [""] * (max(token2idx.values()) + 1)
        for token, idx in token2idx.items():
            vocab.itos[idx] = token
        return vocab
  3. The error occurs during tokenization:

    Error: TypeError: 'int' object is not subscriptable

    at this line in gene_tokenizer.py:

    cls_id = vocab[cls_token]
  4. I'm using this embedding generation function, based on the example in Tutorial_Reference_Mapping_dataset.ipynb (the gene-to-vocab-id mapping step I run beforehand is sketched after the function):

    def get_batch_cell_embeddings(
        adata,
        cell_embedding_mode: str = "cls",
        model=None,
        vocab=None,
        max_length=1200,
        model_configs=None,
        gene_ids=None,
        use_batch_labels=False,
    ) -> np.ndarray:
        """
        Get the cell embeddings for a batch of cells.

        Args:
            adata (AnnData): The AnnData object, with raw counts in
                adata.layers["counts"] and vocab ids in adata.var["id_in_vocab"].
            cell_embedding_mode (str): How to pool the cell embedding; only
                "cls" is handled here.

        Returns:
            np.ndarray: The cell embeddings, shape (n_cells, d_emb).
        """
        # Densify the count layer if it is stored as a sparse matrix.
        count_matrix = (
            adata.layers["counts"]
            if isinstance(adata.layers["counts"], np.ndarray)
            else adata.layers["counts"].A
        )

        # Gene vocabulary ids; every gene must already be mapped into the vocab.
        if gene_ids is None:
            gene_ids = np.array(adata.var["id_in_vocab"])
            assert np.all(gene_ids >= 0)

        if use_batch_labels:
            batch_ids = np.array(adata.obs["batch_id"].tolist())

        if cell_embedding_mode == "cls":
            tokenized_all = tokenize_and_pad_batch(
                count_matrix,
                gene_ids,
                max_len=max_length,
                vocab=vocab,
                pad_token=model_configs["pad_token"],
                pad_value=model_configs["pad_value"],
                append_cls=True,  # append <cls> token at the beginning
                include_zero_gene=False,
            )
            all_gene_ids, all_values = tokenized_all["genes"], tokenized_all["values"]
            src_key_padding_mask = all_gene_ids.eq(vocab[model_configs["pad_token"]])
            with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
                cell_embeddings = model.encode_batch(
                    all_gene_ids,
                    all_values.float(),
                    src_key_padding_mask=src_key_padding_mask,
                    batch_size=64,
                    batch_labels=None,
                    time_step=0,
                    return_np=True,
                )
            # L2-normalize each cell embedding.
            cell_embeddings = cell_embeddings / np.linalg.norm(
                cell_embeddings, axis=1, keepdims=True
            )
        else:
            raise ValueError(f"Unknown cell embedding mode: {cell_embedding_mode}")
        return cell_embeddings

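For completeness, this is how I map genes to vocabulary ids before calling the function, adapted from the reference-mapping tutorial. Note the column name "gene_name" in adata.var is an assumption about my own data layout:

    # Map each gene symbol to its id in the model vocabulary; genes that are
    # missing from the vocab get -1 and are dropped before tokenization.
    # Assumes gene symbols live in adata.var["gene_name"] and that `vocab`
    # supports `in` and `[]` lookups (e.g. GeneVocab).
    adata.var["id_in_vocab"] = [
        vocab[gene] if gene in vocab else -1 for gene in adata.var["gene_name"]
    ]
    gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
    adata = adata[:, gene_ids_in_vocab >= 0]  # keep only genes known to the vocab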
Based on the above, here are my questions:

  1. What is the correct way to handle the vocabulary for tokenization?
  2. Should we be using a specific vocabulary class from scGPT instead of creating a custom one?
  3. Is there an example of the correct vocabulary format and usage for embedding generation?

Here is the environment I am using:

Let me know if you'd like me to provide any additional information or test any specific solutions.

subercui commented 3 weeks ago

Hi, thank you for the question! We recommend using our custom GeneVocab class; it is defined here: https://github.com/bowang-lab/scGPT/blob/7301b51a72f5db321fccebb51bc4dd1380d99023/scgpt/tokenizer/gene_tokenizer.py#L20

One use case can be found here in cell_emb.py; the vocab is also loaded in a similar fashion in the tutorial notebooks:

https://github.com/bowang-lab/scGPT/blob/7301b51a72f5db321fccebb51bc4dd1380d99023/scgpt/tasks/cell_emb.py#L208
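For reference, a minimal loading sketch following the pattern in embed_data from cell_emb.py. The exact special-token list and the keys present in args.json can vary by checkpoint, so treat those as assumptions:

    import json
    from pathlib import Path

    from scgpt.tokenizer.gene_tokenizer import GeneVocab

    model_dir = Path("scGPT_human/")

    # Load the vocabulary with the built-in GeneVocab class rather than a raw dict.
    vocab = GeneVocab.from_file(model_dir / "vocab.json")

    # Ensure the special tokens used during pretraining are present.
    special_tokens = ["<pad>", "<cls>", "<eoc>"]  # assumed for the human checkpoint
    for token in special_tokens:
        if token not in vocab:
            vocab.append_token(token)

    # Unseen tokens fall back to the <pad> index on lookup.
    vocab.set_default_index(vocab["<pad>"])

    # Model hyperparameters saved alongside the checkpoint.
    with open(model_dir / "args.json") as f:
        model_configs = json.load(f)

Passing this vocab object (rather than a plain dict or a hand-rolled class) into tokenize_and_pad_batch should avoid the lookup error above, since the tokenizer expects the GeneVocab interface.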