CASIA-IVA-Lab / FastSAM

Fast Segment Anything

Memory leak issues when implementing fastSAM in a loop? #182

Open hanschanhs opened 9 months ago

hanschanhs commented 9 months ago

Hi,

Is anyone experiencing memory leak issues when implementing fastSAM in a loop?

I have a library of photos which I iterate through, segmenting each photo with FastSAM and extracting only basic coordinates from the segments - the annotations are no longer plotted or saved.

Expected behaviour: Memory use stays the same throughout

Observed behaviour: small but noticeable memory creep. Using tracemalloc, I often see the following building up memory use:

matplotlib version: 3.5.1

I have tried importing matplotlib.pyplot as plt after importing fastSAM
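
Roughly, the tracking loop looks like this (a minimal sketch only; photo_paths and segment_one_image stand in for my actual library iteration and FastSAM call):

import tracemalloc

tracemalloc.start()
baseline = tracemalloc.take_snapshot()

for path in photo_paths:
    segment_one_image(path)  # FastSAM inference + coordinate extraction only
    snapshot = tracemalloc.take_snapshot()
    # compare against the baseline to see which lines keep allocating
    for stat in snapshot.compare_to(baseline, "lineno")[:5]:
        print(stat)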

Help appreciated - thanks!

BrunoGeorgevich commented 9 months ago

Hello @hanschanhs, I had the same problem as you, so I looked for a solution and found something interesting.

Scenario before my solution:

[Figure 1: FastSAMLeak - RAM usage over time before the fix]

Scenario after my solution:

[Figure 2: FastSAMLeakSolved - RAM usage over time after the fix]

Figure 1, captured before I implemented my solution, shows RAM usage increasing over time. In contrast, Figure 2 shows that RAM usage reaches a value and remains stable.

The solution was very simple:

1) Run the FastSAM model inside a torch.no_grad() context to avoid registering gradients during inference.

with torch.no_grad():
    image = input_data.get("image", None)

    if image is None:
        raise ValueError("Image is not provided")

    if self.__initialized is False:
        raise ValueError("Model is not initiated")

    everything_results = self.__model(
        image,
        device=self.__device,
        retina_masks=True,
        imgsz=1024,
        conf=0.8,
        iou=0.5,
    )
    prompt_process = FastSAMPrompt(
        image, everything_results, device=self.__device
    )
    ann = prompt_process.everything_prompt()
    return prompt_process, ann

2) Delete the FastSAMPrompt and the annotations after processing so they can be garbage collected.

output = prompt.plot_to_result(ann, retina=True)

del prompt
del ann
gc.collect()

return output

I hope this solution helps with your problem. Please let me know if it worked for you.

Best regards.

9527-csroad commented 9 months ago

Hello @BrunoGeorgevich, could you show your initialization code? I use this model and its memory usage increases continuously.

BrunoGeorgevich commented 8 months ago


Hello @9527-csroad.

Here you can see the initialization and processing methods:

FastSAM initialization

import gc
from typing import Any

import numpy as np
import torch

from fastsam import FastSAM, FastSAMPrompt

class FastSamModule(AIModule):
    """This class is the implementation of the model FastSAM as an AIModule"""

    __model = None
    __initialized = False
    __device = "cuda:0"

    def initiate(self, model_path: str = "weights/FastSAM-x.pt") -> None:
        """
        Initializes the object by loading the FastSAM model from the specified path. If no path is provided, the default path is "weights/FastSAM-x.pt".

        :param model_path: A string representing the path to the FastSAM model file.
        :type model_path: str
        :return: None"""
        self.__model = FastSAM(model_path)
        self.__initialized = True

FastSAM Processing and Exhibition

    @torch.no_grad()
    def process(self, input_data: dict) -> (FastSAMPrompt, Any):
        """
        The function processes an image using a model and returns a prompt process object and the
        annotations generated from the prompt process.

        :param input_data: Input data dictionary, which must contain the image to be processed.
        {
            "image": np.ndarray
            ...
        }
        :return: a tuple containing two values: `prompt_process` and `ann`.
        """
        with torch.no_grad():
            image = input_data.get("image", None)

            if image is None:
                raise ValueError("Image is not provided")

            if self.__initialized is False:
                raise ValueError("Model is not initiated")

            everything_results = self.__model(
                image,
                device=self.__device,
                retina_masks=True,
                imgsz=1024,
                conf=0.8,
                iou=0.5,
            )
            prompt_process = FastSAMPrompt(
                image, everything_results, device=self.__device
            )
            ann = prompt_process.everything_prompt()
            return prompt_process, ann

    @torch.no_grad()
    def draw_results(
        self, input_data: dict, results: (FastSAMPrompt, Any)
    ) -> np.ndarray:
        """
        The function takes an image and a list of results, and returns the image with the prompt and
        annotation plotted on it, or the original image if there are no results.

        :param input_data: Input data dictionary, which must contain the image to be processed.
        {
            "image": np.ndarray
            ...
        }
        :param results: The `results` parameter is a tuple that contains two elements, the FastSAMPrompt and the annotations
        :type results: (FastSAMPrompt, Any)
        :return: an image with the segmented area drawn on it
        """
        image = input_data.get("image", None)

        if image is None:
            raise ValueError("Image is not provided")

        if results is None:
            return None

        prompt = results[0]
        ann = results[1]

        try:
            output = prompt.plot_to_result(ann, retina=True)
        except IndexError:
            output = image

        del prompt
        del ann
        gc.collect()
        torch.cuda.empty_cache()

        return output
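
For completeness, a minimal usage sketch of the class above (the image path is a placeholder):

import cv2

module = FastSamModule()
module.initiate("weights/FastSAM-x.pt")

image = cv2.imread("photo.jpg")  # placeholder path; any image loaded as a NumPy array works
results = module.process({"image": image})
output = module.draw_results({"image": image}, results)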

I hope this helps you. Please let me know if I can assist you further.

Best regards.

9527-csroad commented 8 months ago

Hi @BrunoGeorgevich, thank you! Your code inspired me and I found my problem: I was saving the everything_results in a dictionary, which kept increasing the GPU memory. I now move the results to RAM and also limit the image size. It's solved now. Thanks again, hope you have a nice day!
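
For anyone doing something similar, the idea is roughly this (a sketch only; model, image, results_cache and image_id are placeholders, and it assumes the Ultralytics-style result format where results[0].masks.data is a torch tensor):

import torch

results_cache = {}

everything_results = model(image, device="cuda:0", retina_masks=True, imgsz=1024)

# keep only what is needed, moved off the GPU, instead of caching the raw results
masks_cpu = everything_results[0].masks.data.detach().cpu()
results_cache[image_id] = masks_cpu

del everything_results
torch.cuda.empty_cache()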

Best regards.

altave-gabriel-viana commented 2 months ago

Hi @BrunoGeorgevich,

Thank you for your code. It has helped reduce the high GPU consumption during inference with FastSAM. However, I'm still experiencing some instability when using FastSAM. The GPU usage fluctuates frequently between 4GB and 8GB. I am running several other models alongside FastSAM in a computer vision system. When I disable the FastSAM code, the GPU usage remains stable at 4GB. But with FastSAM enabled, the usage erratically spikes from 4GB to 8GB. My model initialization and garbage collector management are implemented as per your provided code. Do you have any idea why this is happening?

BrunoGeorgevich commented 2 months ago


Hi @altave-gabriel-viana, how are you? I hope you're doing well. Would you be able to share some code or charts that better illustrate the problem you're encountering? This would help us understand and address your issue more effectively.

altave-gabriel-viana commented 2 months ago

Hi @BrunoGeorgevich, I am fine! Thanks for your response. How are you doing?

Here is a snippet of my code:

import cv2
import numpy as np
import torch
import gc
from FastSAM.fastsam import FastSAM, FastSAMPrompt
from utils.morphological_operations import closing_mask, opening_mask

class Segmenter:
    def __init__(self) -> None:
        self.model = FastSAM("FastSAM-x.pt")
        self.model.to("cuda:0")

    @staticmethod
    def _create_mask_based_on_annotations(image: np.ndarray, annotations: np.ndarray) -> np.ndarray:
        mask = np.zeros(image.shape[:2], dtype=np.uint8)

        for annotation in annotations:
            annotation = annotation.astype(np.uint8)
            annotation = closing_mask(annotation, 3)
            annotation = opening_mask(annotation, 8)

            mask[annotation == 1] = 255

        return mask

    def predict(
        self,
        image: np.ndarray,
        box_prompts: list[list[int]],
    ) -> np.ndarray:
        with torch.no_grad():
            with torch.cuda.amp.autocast():
                rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                everything_results = self.model(
                    rgb_image,
                    device="cuda:0",
                    retina_masks=True,
                    imgsz=1024,
                    conf=0.4,
                    iou=0.9,
                    verbose=False,
                )
                prompt_process = FastSAMPrompt(image, everything_results, device="cuda:0")

                annotations = prompt_process.box_prompt(bboxes=box_prompts)
                mask = self._create_mask_based_on_annotations(image, annotations)

                del everything_results
                del prompt_process
                del annotations
                gc.collect()
                torch.cuda.empty_cache()

                return mask

if __name__ == "__main__":
    segmenter = Segmenter()
    video_capture = cv2.VideoCapture("video.mp4")

    while video_capture.isOpened():
        ret, frame = video_capture.read()

        if not ret:
            break

        mask = segmenter.predict(frame, [[0, 0, 100, 100]])
        cv2.imshow("mask", mask)
        cv2.waitKey(1)

When I run the code above, the GPU memory becomes unstable, going from 2 to 3 GB as shown in the images below:

[Screenshots: GPU memory usage on two consecutive frames - 2024-06-24 09:38:54 and 09:38:55]

The screenshots are of consecutive frames, so I don't see what could be happening: the model is just running inference, yet the GPU memory fluctuates a lot. When I use other PyTorch models, such as VGGs and YOLO from Ultralytics, GPU usage is stable and continuous. Do you have any idea why this is happening?
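
A minimal sketch of how the allocator behaviour could be checked per frame, to tell genuine growth apart from PyTorch's caching allocator reusing memory (the logging placement inside the frame loop is an assumption, not part of the code above):

# inside the while loop, after each predict() call
mask = segmenter.predict(frame, [[0, 0, 100, 100]])

# allocated = memory held by live tensors; reserved = memory kept by the caching allocator
allocated_mb = torch.cuda.memory_allocated("cuda:0") / 1024**2
reserved_mb = torch.cuda.memory_reserved("cuda:0") / 1024**2
print(f"allocated: {allocated_mb:.0f} MiB, reserved: {reserved_mb:.0f} MiB")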