ZhengPeng7 / BiRefNet

[CAAI AIR'24] Bilateral Reference for High-Resolution Dichotomous Image Segmentation
https://www.birefnet.top
MIT License

Inference on multiple images using batching and multiple GPUs #32

Closed: abhishek0093 closed this issue 2 weeks ago

abhishek0093 commented 2 weeks ago

Hi @ZhengPeng7, first of all, many thanks for this great work and for open-sourcing it.

Currently I'm running into an issue, and it would be very helpful if you could help me with it. I have a large dataset of images that vary in size (some are high-resolution, e.g. 4000x4000). I've noticed that, in my case, processing at 1024x1024 only gives better results when both dimensions are below 1024; otherwise, processing at the original resolution gives better results. So for most of my images I have to process at the original dimensions, which are large and varying. To process in batches I implemented a custom dataloader that returns the transformed image tensors. Everything works fine if I resize everything to 1024x1024; however, I run into problems if I decide not to resize. Since PyTorch doesn't allow returning variable-size tensors in a batch, I add extra padding to make every image in a batch match the largest dimensions present in that batch, and later remove the padding to recover the original-shaped transformed image. There are two problems at this stage:

  1. How do I pass these variable-size tensors to the model? As far as I know, tensors of different sizes can't be concatenated/stacked. I could use the padded ones directly, but since their size matches the largest image in the batch, I get CUDA OOM.
  2. I'm using AWS for inference. With a single GPU I can only process a batch size of 4 (when every image is transformed to 1024x1024); above that I get a CUDA OOM error. How can we use multiple GPUs here? I tried running on a p3.8xlarge, which has 4 GPUs, but only 1 GPU seems to be used during inference. I looked into the code, and I don't think it currently supports multiple GPUs; if possible, I was thinking of using the padded tensors directly with more GPU memory (a rough multi-GPU sketch follows the code below).

Here is the code snippet for reference:

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

def custom_collate_fn(data):
    image_transformed, image_np = zip(*data)

    # Largest height/width among the original-size images in this batch.
    max_height = max(img.size(1) for img in image_np)
    max_width = max(img.size(2) for img in image_np)

    # Pad each transformed image symmetrically so every tensor in the batch shares one size.
    transformed_batch = []
    for img in image_transformed:
        pad_left, pad_top = (max_width - img.size(2)) // 2, (max_height - img.size(1)) // 2
        pad_right, pad_bottom = max_width - img.size(2) - pad_left, max_height - img.size(1) - pad_top
        transformed_batch.append(torch.nn.functional.pad(img, [pad_left, pad_right, pad_top, pad_bottom]))
    return torch.stack(transformed_batch), max_width, max_height, [img.size() for img in image_np]

class BirefDataset(Dataset):
    def __init__(self, image_urls):
        self.image_urls = image_urls
        self.biref_transform = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.image_urls)

    def __getitem__(self, idx):
        image_pil = getImage(self.image_urls[idx])  # getImage: my helper that fetches the image from its URL
        resolution = (1024, 1024) if max(image_pil.size) < 1024 else image_pil.size[::-1]

        if resolution == (1024, 1024):
            # If I drop this condition and use self.biref_transform for everything, it works fine.
            image_transformed = self.biref_transform(image_pil)
        else:
            # Don't resize; only convert and normalize.
            image_transformed = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ])(image_pil)

        # pil_to_tensor_transform: my helper that converts the PIL image to a tensor.
        return image_transformed, pil_to_tensor_transform(image_pil)

biref_dataset = BirefDataset(url_list)
dataloader_biref = DataLoader(dataset=biref_dataset, batch_size=config.batch_size_valid, shuffle=False,
                              num_workers=config.num_workers, pin_memory=True, collate_fn=custom_collate_fn)

Later, I call this dataloader something like this:

def inference(dataloader_biref):
    for input_images, max_width, max_height, original_sizes in dataloader_biref:
        input_images = input_images.to(device)
        input_images_original_shape = []

        # Crop the padding away to recover each image's original shape.
        for idx in range(len(input_images)):
            img = input_images[idx]
            img_width, img_height = original_sizes[idx][2], original_sizes[idx][1]
            left_padding, top_padding = (max_width - img_width) // 2, (max_height - img_height) // 2
            input_images_original_shape.append(
                transforms.functional.crop(img, top_padding, left_padding, img_height, img_width))

        # Problem 1: how do I combine input_images_original_shape into one batch
        # when the tensors have different sizes?
        input_images_original = ...

        with torch.no_grad():
            scaled_preds = biref_model(input_images_original)[-1].sigmoid()

inference(dataloader_biref)
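
For the multi-GPU part, one thing I was considering (an untested sketch, not something the repo does out of the box) is wrapping the model in torch.nn.DataParallel so that each batch is split across the visible GPUs; as far as I understand, this only helps when every image in the batch shares the same size, e.g. the 1024x1024 case:

import torch

# Untested sketch: replicate the already-loaded model on every visible GPU so
# that a batch of 4 becomes sub-batches of 1 per GPU on a 4-GPU p3.8xlarge.
device = torch.device('cuda')
if torch.cuda.device_count() > 1:
    biref_model = torch.nn.DataParallel(biref_model)
biref_model = biref_model.to(device).eval()

with torch.no_grad():
    for input_images, max_width, max_height, original_sizes in dataloader_biref:
        scaled_preds = biref_model(input_images.to(device))[-1].sigmoid()
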
abhishek0093 commented 2 weeks ago

I found an interesting workaround. Using a custom sampler and BatchSampler in torch, I'm now able to get variable batch sizes: batch_size 4 for images that can be processed at 1024x1024, and batch_size 1 for bigger images. Not sure if this is the best way; I'd love to hear if someone has another approach. It would also be great if we could use multiple GPUs, so that bigger images could be processed in batches too.
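
Roughly, the batch sampler looks something like this (a simplified sketch; is_small_flags is a list I precompute from the image sizes, marking which images fit within 1024x1024):

from torch.utils.data import Sampler, DataLoader

class SizeBucketBatchSampler(Sampler):
    # Yields batches of `small_batch_size` indices for images that fit within
    # 1024x1024 and single-index batches for larger images, so every batch is
    # homogeneous and the padding done in the collate function stays small.
    def __init__(self, is_small_flags, small_batch_size=4):
        self.is_small_flags = is_small_flags
        self.small_batch_size = small_batch_size

    def __iter__(self):
        small = [i for i, flag in enumerate(self.is_small_flags) if flag]
        large = [i for i, flag in enumerate(self.is_small_flags) if not flag]
        for start in range(0, len(small), self.small_batch_size):
            yield small[start:start + self.small_batch_size]
        for i in large:
            yield [i]

    def __len__(self):
        n_small = sum(self.is_small_flags)
        return -(-n_small // self.small_batch_size) + (len(self.is_small_flags) - n_small)

# batch_sampler replaces batch_size/shuffle in the DataLoader call:
batch_sampler = SizeBucketBatchSampler(is_small_flags, small_batch_size=4)
dataloader_biref = DataLoader(dataset=biref_dataset, batch_sampler=batch_sampler,
                              num_workers=config.num_workers, pin_memory=True,
                              collate_fn=custom_collate_fn)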

ZhengPeng7 commented 2 weeks ago

Hi, Abhishek. Thanks for your interest and for describing your problems in such detail. I'm always glad to help with these things, but I've been quite tired recently. I actually read your message last night, but I'm only taking the time now to reply.

About the effectiveness of different sizes for training (1024x1024) and inference: I also tried this kind of thing to see whether I could obtain better results on benchmarks (see keep_size in dataset.py). In my experiments, using images at their original sizes for inference did not improve the results.

About dividing a large image into multiple patches for inference on more GPUs: in the existing pipeline, I think it's not possible, since there is usually more than one object in the image -- salient object detection (SOD) is the first step to be conducted, and cropping the image into patches would destroy the semantics needed for target localization. However, I think it's very possible to achieve this with a box prompt, which doesn't rely on SOD over the whole image -- divide the box prompt and the image simultaneously and run separate inference on multiple GPUs.

If possible, I will try to train a BiRefNet with 2048x2048 training data or mixed-resolution data.

I might not have answered every question well, so feel free to reply if you still have questions :)

abhishek0093 commented 2 weeks ago

@ZhengPeng7, thanks for replying. I understand that you have many commitments, and I really appreciate you taking the time to actively maintain this repository and reply to people's issues. My issue is currently resolved: I was able to handle batching with variable batch sizes/padding, and for larger images, simply resizing to the maximum size my GPU setup can handle seems to be the easiest solution right now. Thank you for the help. I'm marking this issue as closed.

ZhengPeng7 commented 2 weeks ago

Okay, sorry for not replying to you in more detail. But I'm really happy to answer the questions you've provided here, since they're listed very clearly.