huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

BLIP2 hangs after loading shards, no errors #22064

Closed. thely closed this issue 1 year ago.

thely commented 1 year ago

System Info

python: 3.9.13
torch: 1.12.0+cu113
transformers: 4.27.0.dev0

Note: I'm on an HPC, running everything through SLURM, and I'm not privy to what kind of CPU I'm using.

CPU: Unknown
GPU: NVIDIA A100
GPU memory: 40GB

Who can help?

@ArthurZucker @younesbelkada @sgugger

Reproduction

The original version of this code was borrowed from the zero-shot inference article on Medium, then expanded to handle a larger set of images, organized in folders like this:

train2/
    company1/
        image1.png
        image2.png
    company2/
        image1.png
        image2.png
        etc.
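A minimal sketch of walking this layout with the standard library's pathlib, assuming exactly one level of company folders under train2/ with the PNG files directly inside each one. `collect_images` is a hypothetical helper for illustration, not part of transformers:

```python
from pathlib import Path


def collect_images(root):
    """Map each company folder under `root` to a sorted list of its PNG paths.

    Hypothetical helper: assumes the train2/<company>/<image>.png layout above.
    Files sitting directly in `root` (not in a company folder) are ignored.
    """
    root = Path(root)
    return {
        company.name: sorted(company.glob("*.png"))
        for company in sorted(root.iterdir())
        if company.is_dir()
    }
```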

What happens is that the model's checkpoint shards load, and then the job hangs forever on this line:

Loading checkpoint shards: 100%|██████████| 2/2 [00:25<00:00, 12.90s/it]

I'm not training or fine-tuning, just trying to run normal inference. It doesn't error out, and SLURM doesn't report an OOM error either. It just sits there forever. My running allocation (which tracks how many hours I've used for jobs sent via SLURM to the HPC) also isn't incrementing for these jobs at all, which makes me think there's some error I can't see. (If I check squeue, the time on the job itself is still ticking up, but somehow that time isn't being applied to my overall time limit.)

I can't tell whether this is some hidden OOM error, since I'm working with about 7GB of image files. I attempted batch inference a few weeks ago, but it wasn't working at the time.

The single-image version of the BLIP2 inference code works correctly, though, and typically finishes before I can even tail -f the log file. Both pieces of code are below for reference.

First, the code that isn't working, which runs inference over a folder full of folders full of images:

from transformers import AutoProcessor, Blip2ForConditionalGeneration
import torch
from PIL import Image
import glob
import json

print("finished imports")

# big list of brand names, printing keys to make sure anything works before the shards make everything hang
folder = ".../inputs/"
brands = {}
with open(folder + "full_brands.json") as jsonf:
        brands = json.load(jsonf)

keys = sorted(brands.keys())
print(keys[0:10])

cachedir = ".../hfcache"

processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b", cache_dir=cachedir)
print("processor good")

try:
        model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, cache_dir=cachedir)
# neither the error below nor the else branch ever prints; we hang here.
except Exception as err:
        print(err)
else:
        print("blip ready")

print("model loaded")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

retval = []

# I haven't reached this section of the code in about a day, but it's here for reference in case this is what's making things hang
for slug in keys:
        print(slug)
        bname = brands[slug]["name"]
        image_files = glob.glob(folder + "train2/" + slug + "/*.png")
        # keep each image paired with its path, so a file that fails to open
        # doesn't shift the file names recorded in retval
        images = []

        for path in image_files:
                try:
                        images.append((path, Image.open(path).convert("RGBA")))
                except OSError:  # PIL's UnidentifiedImageError is an OSError subclass
                        print("image non-functional: " + path)

        for path, image in images:
                print(".", end="")
                prompt = "an image of " + bname + " with"
                inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)

                generated_ids = model.generate(**inputs, max_new_tokens=20)
                generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
                desc = prompt + " " + generated_text
                retval.append({"file_name": path, "text": desc})

                print(desc)

with open(folder + "blip_output.json", "w") as jsonf:
        json.dump(retval, jsonf, indent=2)
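The loop above opens every image for a brand before running any inference on them. A minimal sketch of a lazier alternative, assuming Pillow is installed: a hypothetical `iter_images` generator that opens one image at a time, skips files that fail to open, and keeps each image paired with its path. It converts to RGB rather than the RGBA used above, and the `opener` parameter exists only so the helper can be exercised without real image files:

```python
def iter_images(paths, opener=None):
    """Yield (path, image) pairs one at a time, skipping files that fail to open.

    Hypothetical helper. By default each path is opened with Pillow and
    converted to RGB; pass `opener` to substitute another loader.
    """
    if opener is None:
        from PIL import Image  # imported lazily, only when the default loader is used

        def opener(path):
            # convert() returns a fully loaded copy, so the file can be closed here
            with Image.open(path) as img:
                return img.convert("RGB")

    for path in paths:
        try:
            image = opener(path)
        except OSError as exc:  # PIL.UnidentifiedImageError subclasses OSError
            print(f"skipping {path}: {exc}", flush=True)
            continue
        yield path, image
```

In the brand loop, `for path, image in iter_images(image_files):` could then replace both the loading loop and the inference loop, so only the image currently being captioned is held in memory.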

Second, the code that is working, which runs inference on a single image:

from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration
import torch

path = ".../newyorker.jpg"
image = Image.open(path).convert('RGBA')
print(image)
cachedir = ".../hfcache"

processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b", cache_dir=cachedir)
print("processor loaded")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, cache_dir=cachedir)

print("model loaded")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

print("cuda invoked")

inputs = processor(image, return_tensors="pt").to(device, torch.float16)

generated_ids = model.generate(**inputs, max_new_tokens=20)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)

Expected behavior

I'd think BLIP/transformers would either error out or continue to the rest of the code. I wish I knew what was going on.

As another point of reference, the top chunk of code was working yesterday on torch 1.10.0 and transformers 4.26.1, but between then and now, something about torch got updated such that torch 1.10.0 wasn't working with the A100 GPUs. (I was getting the "no binary exists for this device" error.) When I had to move up to torch 1.12.0, Blip2ForConditionalGeneration no longer existed, so I had to bump up to transformers 4.27.0.dev0, and here we are now.

But the smaller script still works. I don't know what effect all those images have on the script itself, and since the code never gets far enough to load any images, I don't understand how this is happening.
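Since the behavior changed along with the torch, transformers, Python, and Pillow versions, it can help to log the exact versions at the top of each SLURM job. A minimal sketch using only the standard library's importlib.metadata (Python 3.8+); `report_versions` is a hypothetical helper, and the `transformers-cli env` command prints a fuller environment report of the same kind:

```python
import platform
from importlib.metadata import PackageNotFoundError, version


def report_versions(packages=("torch", "transformers", "Pillow")):
    """Return the Python version plus the installed version of each package.

    Hypothetical helper; packages that aren't installed are reported as
    "not installed" instead of raising.
    """
    info = {"python": platform.python_version()}
    for name in packages:
        try:
            info[name] = version(name)
        except PackageNotFoundError:
            info[name] = "not installed"
    return info


if __name__ == "__main__":
    for name, ver in report_versions().items():
        print(f"{name}: {ver}", flush=True)
```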

younesbelkada commented 1 year ago

Hi @thely, thanks for the issue! It might indeed be a CPU-related issue, but this is hard to tell. I'd try loading the model with low_cpu_mem_usage=True:

model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, cache_dir=cachedir, low_cpu_mem_usage=True)

I would also try accelerate + 8-bit loading, since it lets you load the model with lower memory requirements. First:

pip install accelerate bitsandbytes

Then:

model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", device_map="auto", load_in_8bit=True)

thely commented 1 year ago

@younesbelkada I don't know what about your comment helped me, but it helped me realize the problem wasn't in transformers or BLIP2.

In the initial run of this code on torch 1.10.0, something in the setup (Pillow? torch? Python?) let print statements show up in the log regularly as the code progressed. After the change to torch 1.12.0, which also changed the active Python version from 3.8.x to 3.9.x and the Pillow version from 8.x to 9.x, I didn't see any print output until all the work had completed: image loading, running through BLIP, output, etc. So it wasn't hanging; I just couldn't tell that anything was happening until the very end. I'm not sure whether Python 3.9.x handles print statements differently, but I'm leaving this here in case it helps someone else.
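A likely explanation, though the thread doesn't confirm it: when stdout is redirected to a file, as it is for a SLURM job's log, CPython block-buffers it instead of line-buffering it, so print output can sit in memory until the buffer fills or the process exits. The `print(".", end="")` progress dots never end a line at all. A minimal sketch of one workaround, with `log` as a hypothetical wrapper: switch stdout to line buffering and flush explicitly. Running the script with `python -u`, or setting `PYTHONUNBUFFERED=1` in the job script, disables the buffering without any code changes.

```python
import sys

# When stdout is a file or pipe (e.g. a SLURM log), CPython block-buffers it.
# Line buffering flushes at every newline. reconfigure() exists on text streams
# in Python 3.7+; the guard covers a stdout that has been swapped for another object.
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(line_buffering=True)


def log(*args, **kwargs):
    """print() that flushes immediately, including end="" progress output."""
    kwargs.setdefault("flush", True)
    print(*args, **kwargs)


log("finished imports")
log(".", end="")  # flushed right away even though no newline is written
```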

For my sanity, I fixed it by running through 100 folders at a time.
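Processing a fixed number of folders per job can be sketched as follows, assuming the job is submitted as a SLURM job array (e.g. `sbatch --array=0-9 ...`) so that each task reads its index from the `SLURM_ARRAY_TASK_ID` environment variable. `chunk` and `select_chunk` are hypothetical helpers, and the default size of 100 mirrors the 100 folders mentioned above:

```python
import os


def chunk(items, size):
    """Split `items` into consecutive lists of at most `size` elements."""
    if size <= 0:
        raise ValueError("size must be positive")
    return [items[i:i + size] for i in range(0, len(items), size)]


def select_chunk(keys, size=100, index=None):
    """Return the slice of `keys` this job should process.

    If `index` isn't given, it's read from SLURM_ARRAY_TASK_ID, defaulting to 0
    when the variable is unset (e.g. when not running as part of a job array).
    """
    if index is None:
        index = int(os.environ.get("SLURM_ARRAY_TASK_ID", "0"))
    return chunk(keys, size)[index]
```

In the original script, `for slug in select_chunk(keys):` would take the place of `for slug in keys:`; writing one output file per task (rather than a single `blip_output.json`) would keep the array tasks from overwriting each other.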

younesbelkada commented 1 year ago

Awesome! Thank you for the update @thely !