kuleshov-group / caduceus

Bi-Directional Equivariant Long-Range DNA Sequence Modeling

Unable to run the model on input sequence with length of 131072 #4

Closed: qiaoqiaoLF closed this issue 3 months ago

qiaoqiaoLF commented 3 months ago

I tried to run the model with an input sequence length of 131072 on an A100 80GB GPU, but "torch.cuda.OutOfMemoryError: CUDA out of memory" occurred every time. I used the pretrained model from https://huggingface.co/kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16 with batch_size=4. I also tried mixed precision, but the problem persists.

yair-schiff commented 3 months ago

Can you please share a bit more detail about the command you’re using to launch this (are you using your own code or the scripts from this repo)?

qiaoqiaoLF commented 3 months ago

Hi yair-schiff!

I am using a batch size of four here and running on an A100 80GB GPU. As mentioned in the paper and the Hugging Face model card, the model was pretrained on sequences of length 131k with a batch size of eight. Did you use techniques like tensor parallelism?

This is the code that produces the "torch.cuda.OutOfMemoryError: CUDA out of memory". You can run it directly with Python.

import torch
# Load the pretrained Caduceus model directly from the Hugging Face Hub
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16",
    trust_remote_code=True,
)

# Random token IDs in [7, 10): batch size 4, sequence length 131072
input_ids = torch.randint(7, 10, (4, 131072), dtype=torch.long)
input_ids = input_ids.cuda()
model = model.cuda()
outputs = model(input_ids)

yair-schiff commented 3 months ago

@qiaoqiaoLF, apologies for the confusion from the model card. The batch size was 8, but we used 8 GPUs for training, so each one had a per-GPU batch size of 1. Can you try reducing your batch size and see if the issue persists?
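
For completeness, here is a minimal sketch of the suggested fix: the repro from above with batch size 1, matching the per-GPU batch size used in pretraining. The torch.no_grad() context is an added assumption (inference only) that further lowers peak memory by not retaining activations for backprop; it is not part of the original suggestion.

import torch
from transformers import AutoModelForSequenceClassification

# Same checkpoint as in the repro above
model = AutoModelForSequenceClassification.from_pretrained(
    "kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16",
    trust_remote_code=True,
).cuda()

# Batch size 1 matches the per-GPU batch size used during pretraining
input_ids = torch.randint(7, 10, (1, 131072), dtype=torch.long).cuda()

# Assumption: inference only, so skip storing activations for backprop
with torch.no_grad():
    outputs = model(input_ids)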

qiaoqiaoLF commented 3 months ago

> @qiaoqiaoLF, apologies for the confusion from the model card. The batch size was 8, but we used 8 GPUs for training, so each one had a per-GPU batch size of 1. Can you try reducing your batch size and see if the issue persists?

It works! Thanks for the explanation!