facebookresearch / esm

Evolutionary Scale Modeling (esm): Pretrained language models for proteins

Model consumes too many computing resources. #214

Closed walt676 closed 2 years ago

walt676 commented 2 years ago

Hello ESM team, I really appreciate you sharing the code for protein language modeling.

However, I cannot run the model esm1_t6_43M_UR50S, which has the fewest parameters, whether I use a GPU or the CPU. My machine has 32 GB of RAM, and my graphics card is a 1080 Ti with 11 GB of video memory.

It only works when I set the batch size to 4. Is this normal?

My code is as follows:

import torch
import torch.nn as nn
import esm

class ProteinEmbedding(nn.Module):
    def __init__(self, model_path="models/esm1_t6_43M_UR50S.pt",
                 repr_layer=6) -> None:
        super(ProteinEmbedding, self).__init__()
        self.repr_layer = repr_layer
        self.model, alphabet = esm.pretrained.load_model_and_alphabet(model_path)

        # my rewritten BatchConverter that returns only the token tensor (see note below)
        self.batch_converter = BatchConverter(alphabet=alphabet)

    def forward(self, sequences):
        batch_tokens = self.batch_converter(sequences)

        with torch.no_grad():
            token_repr = self.model(batch_tokens, repr_layers=[self.repr_layer])["representations"][self.repr_layer]

        # mean-pool over residues, skipping the BOS/CLS token at position 0
        seqs_repr = []
        for i, seq in enumerate(sequences):
            seqs_repr.append(token_repr[i, 1:len(seq) + 1].mean(0))
        return seqs_repr

The input sequences to the forward function are raw protein sequences without labels, so I rewrote the BatchConverter so that it only returns tokens.

In addition, I would like to ask: is token_repr[:, 0] the <CLS> representation in your code?

tomsercu commented 2 years ago

The example is missing some details to be an MWE, but it looks like you are trying to feed a large dataset through the model all at once. You can look at https://github.com/facebookresearch/esm/blob/main/scripts/extract.py#L87 for an example of how to do this in a batched fashion.

Yes, the CLS/BOS token is at index 0; see L124 of the same script. But beware: this token was not supervised in any way.
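
A minimal sketch of that batched pattern, roughly along the lines of extract.py (the checkpoint path, toy sequences, and chunk size below are placeholders):

import torch
import esm

# Load model and alphabet; the local checkpoint path is just an example.
model, alphabet = esm.pretrained.load_model_and_alphabet("models/esm1_t6_43M_UR50S.pt")
batch_converter = alphabet.get_batch_converter()
model.eval()

data = [  # (label, sequence) pairs
    ("protein1", "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
]

chunk_size = 4
seq_reprs = []
with torch.no_grad():
    # process the dataset in small chunks instead of all at once
    for start in range(0, len(data), chunk_size):
        chunk = data[start:start + chunk_size]
        labels, seqs, tokens = batch_converter(chunk)
        out = model(tokens, repr_layers=[6])
        token_repr = out["representations"][6]  # token_repr[:, 0] is the CLS/BOS position
        for i, seq in enumerate(seqs):
            # mean over residues 1..len(seq), skipping BOS (and the trailing EOS/padding)
            seq_reprs.append(token_repr[i, 1:len(seq) + 1].mean(0))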

(accidental close - feel free to close if this answers your question)

walt676 commented 2 years ago

Hi Tom, thank you for your help. But actually, I do feed only one batch of data at a time.

  1. May I ask what the average protein sequence length and the batch size were when you trained the model, and how much video memory that required?
  2. If I wish to predict correlations between different protein sequences, which layer's output from each model should I use as the representation? For example, for esm1_t6_43M_UR50S, what should I set repr_layers to?
tomsercu commented 2 years ago

  1. If you provide a minimal working example and the error message you see, we can take a look. The 6-layer model is really small and should be able to process tens of sequences of max length 1024.
  2. If you want a single representation for the sequence, you may want to use the mean over the sequence.
walt676 commented 2 years ago

Thank you for your prompt reply! For Q1, the data format in the training set is as follows:

MAFTFAAFCYMLALLLTAALIFFAIWHIIAFDELKTDYKNPIDQCNTLNPLVLPEYLIHAFFCVMFLCAAEWLTLGLNMPLLAYHIWRYMSRPVMSGPGLYDPTTIMNADILAYCQKEGWCKLAFYLLAFFYYLYGMIYVLVSS MAAAAGRLLWSSVARHASAISRSISASTVLRPVASRRTCLTDILWSASAQGKSAFSTSSSFHTPAVTQHAPYFKGTAVVNGEFKELSLDDFKGKYLVLFFYPLDFTFVCPTEIVAFSDKANEFHDVNCEVVAVSVDSHFSHLAWINTPRKNGGLGHMNITLLSDITKQISRDYGVLLESAGIALRGLFIIDPNGVVKHLSVNDLPVGRSVEETLRLVKAFQFVETHGEVCPANWTPESPTIKPSPTASKEYFEKVHQ 269

As you can see, I am trying to predict a label (269 in the above case) from two protein sequences. Here is my training code:

    # (imports assumed: sys, torch, torch.nn as nn, esm, and torch.utils.data.DataLoader)
    torch.manual_seed(511)

    device = LinkConfig.device
    emb_model, alphabet = esm.pretrained.load_model_and_alphabet(LinkConfig.emb_path)
    emb_model.eval()
    emb_model.to(device)

    batch_converter = LinkBatchConverter(alphabet)

    train_dataset = ProteinLinkDataset.from_file(LinkConfig.train_dataset_path)
    train_loader = DataLoader(
        train_dataset, collate_fn=batch_converter,
        shuffle=True,
        batch_size=LinkConfig.batch_size
    )

    print(f"Read {LinkConfig.train_dataset_path} with {len(train_dataset)} sequences")
    # test_dataset is loaded the same way as train_dataset (loading code not shown here)
    print(f"Read {LinkConfig.test_dataset_path} with {len(test_dataset)} sequences")

    pred_model = ProteinLinkPredictor()
    pred_model.to(device)

    loss_fn = nn.CrossEntropyLoss()

    optimizer = torch.optim.Adam(pred_model.parameters(), lr=LinkConfig.lr)

    len_dataloader = len(train_loader)
    batch_size = LinkConfig.batch_size

    for epoch in range(LinkConfig.epoch_num):
        pred_model.train()
        difference = 0
        difference_num = 0
        for i, (labels, first_proteins, second_proteins) in enumerate(train_loader):
            first_embed = get_pretrained_embedding(emb_model, first_proteins)
            second_embed = get_pretrained_embedding(emb_model, second_proteins)

            pred_model.zero_grad()

            outputs = pred_model(first_embed, second_embed)

            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            pred = outputs.data.max(1)[1]
            difference += torch.abs(labels-pred).sum(dim=0).item()
            difference_num += batch_size
            sys.stdout.write('\r epoch: %d, [iter %d / all %d] loss %f avg difference %d' \
                             % (epoch, i + 1, len_dataloader, loss.data.cpu().item(),
                                difference // difference_num))
            sys.stdout.flush()

def get_pretrained_embedding(model, inputs):
    assert LinkConfig.final_repr in LinkConfig.repr_layers
    with torch.no_grad():
        token_repr = model(inputs, repr_layers=LinkConfig.repr_layers)["representations"]
        token_repr = token_repr[LinkConfig.final_repr]
        batch_size, _, embed_size = token_repr.size()
        seqs_repr = torch.empty((batch_size, embed_size), dtype=torch.float32).to(LinkConfig.device)
        for i, seq in enumerate(inputs):
            # note: `inputs` are padded token tensors, so len(seq) is the padded length
            # and this mean also covers the EOS and any padding positions
            seqs_repr[i] = token_repr[i, 1:len(seq) + 1].mean(0)
    return seqs_repr
class ProteinLinkDataset(Dataset):
    def __init__(self, labels, first_proteins, second_proteins):
        super(ProteinLinkDataset, self).__init__()
        self.labels = labels
        self.first_proteins = first_proteins
        self.second_proteins = second_proteins

    @classmethod
    def from_file(cls, file_path):
        first_proteins = []
        second_proteins = []
        labels = []

        with open(file_path, "r") as f:
            for line in f:
                line = line.strip("\n")
                first, second, label = line.split(" ")
                first_proteins.append(first)
                second_proteins.append(second)
                labels.append(label)

        return cls(labels, first_proteins, second_proteins)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index: int):
        return self.labels[index], self.first_proteins[index], self.second_proteins[index]
class LinkBatchConverter(object):
    def __init__(self, alphabet):
        self.alphabet = alphabet

    def __call__(self, raw_batch):
        batch_size = len(raw_batch)
        batch_labels, first_str_list, second_str_list = zip(*raw_batch)
        first_encoded_list = [self.alphabet.encode(seq_str) for seq_str in first_str_list]
        second_encoded_list = [self.alphabet.encode(seq_str) for seq_str in second_str_list]
        first_max_len = max(len(seq_encoded) for seq_encoded in first_encoded_list)
        second_max_len = max(len(seq_encoded) for seq_encoded in second_encoded_list) 

        first_tokens = torch.empty(
            (
                batch_size,
                first_max_len + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos)
            ),
            dtype=torch.int64,
        )

        second_tokens = torch.empty(
            (
                batch_size,
                second_max_len + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos)
            ),
            dtype=torch.int64
        )

        first_tokens.fill_(self.alphabet.padding_idx)
        second_tokens.fill_(self.alphabet.padding_idx)

        labels = []
        # res_strs = []
        for i, (label, fir_encoded, sec_encoded) in enumerate(
            zip(batch_labels, first_encoded_list, second_encoded_list)
        ):
            labels.append(int(label))
            if self.alphabet.prepend_bos:
                first_tokens[i, 0], second_tokens[i, 0] = self.alphabet.cls_idx, self.alphabet.cls_idx

            first_tokens[
                i,
                int(self.alphabet.prepend_bos): len(fir_encoded) + int(self.alphabet.prepend_bos),
            ] = torch.tensor(fir_encoded, dtype=torch.int64)
            second_tokens[
                i,
                int(self.alphabet.prepend_bos): len(sec_encoded) + int(self.alphabet.prepend_bos),
            ] = torch.tensor(sec_encoded, dtype=torch.int64)
            if self.alphabet.append_eos:
                first_tokens[i, len(fir_encoded) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx
                second_tokens[i, len(sec_encoded) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx

        return torch.tensor(labels, dtype=torch.int64).to(DefaultConfig.device), \
               first_tokens.to(DefaultConfig.device), \
               second_tokens.to(DefaultConfig.device)

I think I do feed one batch of data at a time, and I can run the code with batch size 6. But if I do not use with torch.no_grad() on the esm1_t6_43M_UR50S model and instead update its parameters as well, I get RuntimeError: CUDA out of memory even with batch size 1. I wonder if my code is wrong?

tomsercu commented 2 years ago

Hmm, I don't see an obvious issue with the code. In principle it should be doable for small batches (batch size 1 and sequence length < 1024 -- or are you feeding in very long sequences?). Back of envelope for esm1_t6_43M_UR50S:
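
An illustrative sketch of such an estimate, assuming esm1_t6_43M_UR50S has 6 layers, an embedding dimension of 768, and 12 attention heads (rough fp32 numbers for a forward pass at batch size 1, sequence length 1024; the dimensions and totals are assumptions, not measurements):

# Rough forward-pass memory for the assumed dims (6 layers, embed 768, 12 heads)
weights = 43e6 * 4                      # ~170 MB of fp32 parameters
hidden  = 6 * 1024 * 768 * 4            # ~19 MB of per-layer hidden states
attn    = 6 * 12 * 1024 * 1024 * 4      # ~300 MB of attention maps
print(f"~{(weights + hidden + attn) / 2**20:.0f} MiB")  # well under 1 GB, so 11 GB is plenty for inference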

To debug, I'd check a couple of things with nvidia-smi:

  1. Is something else running on your GPU and occupying memory?
  2. Is memory usage growing during your training loop? If yes, find the memory leak: your PyTorch graph isn't being cleared because something is holding on to one of the model's outputs (see the sketch below).
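
As a complement to nvidia-smi, a small sketch for watching this from inside the training loop, using the standard torch.cuda.memory_allocated / max_memory_allocated utilities (where and how often to log is up to you):

import torch

# Call once per iteration, e.g. right after optimizer.step(). If the allocated
# number keeps climbing across iterations, something is still holding a
# reference to the graph (a common culprit: accumulating `loss` for logging
# instead of `loss.item()`).
def log_cuda_memory(step):
    alloc = torch.cuda.memory_allocated() / 2**20      # MiB currently held by tensors
    peak = torch.cuda.max_memory_allocated() / 2**20    # MiB peak since program start
    print(f"step {step}: {alloc:.0f} MiB allocated (peak {peak:.0f} MiB)")
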
walt676 commented 2 years ago

Thank you so much for taking the time to review and answer my question, it helped a lot!