bowang-lab / scGPT

https://scgpt.readthedocs.io/en/latest/
MIT License

Where can I find the code for the attention mask used in generative pre-training? #94

Open · Haonan917 opened this issue 1 year ago

Haonan917 commented 1 year ago

Thank you for your great work! However, it seems that the latest code does not implement the special attention-mask design you describe for generative pre-training?

def generate(
    self,
    cell_emb: Tensor,
    src: Tensor,
    values: Optional[Tensor] = None,
    src_key_padding_mask: Optional[Tensor] = None,
    gen_iters: int = 1,
    batch_labels: Optional[Tensor] = None,  # (batch,)
) -> Tensor:
    """
    Args:
        cell_emb(:obj:`Tensor`): shape (batch, embsize)
        src(:obj:`Tensor`): shape (batch, seq_len)
        values(:obj:`Tensor`): shape (batch, seq_len), optional
        src_key_padding_mask(:obj:`Tensor`): shape (batch, seq_len), optional
        gen_iters(:obj:`int`): number of generation iterations
        batch_labels(:obj:`Tensor`): shape (batch,), optional
    """
    # TODO: should have a tag indicate the generation mode
    # TODO: if gen_iters > 1, should have a tag indicate the current iteration
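
For reference, this is the kind of masking I expected from my reading of the manuscript: each query attends to the cell-embedding token and the genes whose values are already known, plus itself, while the remaining unknown genes stay hidden from one another. The snippet below is only my own sketch of that idea; the build_generative_attn_mask helper and the PyTorch boolean-mask convention are my assumptions, not anything taken from this repository.

import torch

def build_generative_attn_mask(known: torch.Tensor) -> torch.Tensor:
    # My own sketch of the mask I expected, NOT code from this repo.
    # `known` has shape (batch, seq_len) and is True for the cell-embedding
    # token and for genes whose expression values are already given.
    batch, seq_len = known.shape
    # Every query position may attend to all "known" key positions ...
    allowed = known.unsqueeze(1).expand(batch, seq_len, seq_len).clone()
    # ... plus itself, so an unknown gene still sees its own query token.
    allowed |= torch.eye(seq_len, dtype=torch.bool, device=known.device)
    # PyTorch attention masks treat True as "do NOT attend", so invert.
    # (Depending on the attention layer, this may still need to be repeated
    # per head or converted to an additive float mask.)
    return ~allowed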
subercui commented 1 year ago

Hi, thank you for the question! Currently, the pretraining code is in the dev-temp branch. You may also find a training script here: https://github.com/bowang-lab/scGPT/blob/dev-temp/examples/pretrain.py , which will guide you through the process. As for the attention mask used for generation at inference time, we'll release a dedicated tutorial for cell generation soon.

xinyu-dev commented 11 months ago

This is great to hear. In addition to pretrain.py, would you be able to provide a slice of the dataset to test the pretraining script? @subercui

DanielFLevine commented 10 months ago

@subercui Is the generative pretraining attention mask used in this script? I'm unable to locate where it's implemented in this branch. If I'm following the code correctly, the generative_forward function essentially uses the forward function of this class: https://github.com/bowang-lab/scGPT/blob/dev-temp/scgpt/model/flash_layers.py#L389 , but there doesn't seem to be anything special going on here.

Maybe I don't understand the masking procedure? I thought the attention mask needs to be generated on the fly, so I'm searching for a forward function that loops and unmasks genes based on likelihood scores or something similar. Is this not how it works?
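
For concreteness, this is roughly the loop I was expecting to find somewhere in the forward pass. It is purely my own hypothetical sketch: the model(...) call signature, the returned confidence scores, and the top-k reveal heuristic are all my assumptions, not scGPT APIs.

import torch

@torch.no_grad()
def iterative_generate(model, cell_emb, src, gen_iters=4, frac_per_iter=0.25):
    # Hypothetical iterative-unmasking loop (my guess, not scGPT code):
    # start with all gene values unknown and reveal the most confident
    # predictions after each pass.
    batch, seq_len = src.shape
    known = torch.zeros(batch, seq_len, dtype=torch.bool, device=src.device)
    values = torch.zeros(batch, seq_len, device=src.device)
    k = max(1, int(seq_len * frac_per_iter))

    for _ in range(gen_iters):
        # Assumed interface: per-gene predictions plus some confidence score,
        # given the current known/unknown split.
        pred, confidence = model(cell_emb, src, values, known)  # assumption
        # Reveal the k most confident still-unknown genes this round.
        confidence = confidence.masked_fill(known, float("-inf"))
        new_idx = confidence.topk(k, dim=-1).indices
        values.scatter_(1, new_idx, pred.gather(1, new_idx))
        known.scatter_(1, new_idx, True)
    return values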

olavdc commented 3 months ago

Hi @DanielFLevine,

I've come to the same conclusion as you. Have you been able to figure this out since then? Thanks in advance!