hkchengrex / XMem

[ECCV 2022] XMem: Long-Term Video Object Segmentation with an Atkinson-Shiffrin Memory Model
https://hkchengrex.com/XMem/
MIT License

Memory distribution algorithms #106

Closed MaxTeselkin closed 1 year ago

MaxTeselkin commented 1 year ago

Hi! Thanks for the great architecture! I am building a web application based on XMem and I am running into a CUDA out-of-memory error when tracking over 40+ frames. My GPU has 24 GB of memory. My video resolution is quite high (1920 x 1080), so I tried interpolating the input mask and frames to a lower resolution. That helped (without resizing I was able to track only 5 frames; with resizing I can now track ~35 frames). But your paper says it is possible to track even 1000 frames. Can you please share the memory distribution algorithms that make tracking over a large number of frames possible?

Here is my code for loading model on GPU:

def load_on_device(
        self,
        model_dir: str,
        device: Literal["cpu", "cuda", "cuda:0", "cuda:1", "cuda:2", "cuda:3"] = "cpu",
    ):
        self.device = torch.device(device)
        # disable gradient calculation
        torch.set_grad_enabled(False)
        # define model configuration
        self.config = {
            "top_k": 30,
            "mem_every": 5,
            "deep_update_every": -1,
            "enable_long_term": True,
            "enable_long_term_count_usage": True,
            "num_prototypes": 128,
            "min_mid_term_frames": 5,
            "max_mid_term_frames": 10,
            "max_long_term_elements": 10000,
        }
        # build model
        self.model = XMem(self.config, weights_location_path, map_location=self.device).eval()
        self.model = self.model.to(self.device)

Here is my code for prediction:

def predict(
            self,
            frames: List[np.ndarray],
            input_mask: np.ndarray,
    ):
        # object IDs should be consecutive and start from 1 (0 represents the background)
        num_objects = len(np.unique(input_mask))
        # load processor
        processor = InferenceCore(self.model, config=self.config)
        processor.set_all_labels(range(1, num_objects))
        # resize input mask
        original_width, original_height = frames[0].shape[1], frames[0].shape[0]
        scaler = min(original_width, original_height) / 480
        resized_width = int(original_width / scaler)
        resized_height = int(original_height / scaler)
        input_mask = torch.from_numpy(input_mask)
        input_mask = torch.unsqueeze(input_mask, 0)
        input_mask = torch.unsqueeze(input_mask, 0)
        input_mask = torch.nn.functional.interpolate(input_mask, (resized_height, resized_width), mode="nearest")
        input_mask = input_mask.squeeze().numpy()
        results = []
        # track input objects' masks
        with torch.cuda.amp.autocast(enabled=True):
            for i, frame in enumerate(frames):
                # preprocess frame
                frame = frame.transpose(2, 0, 1)
                frame = torch.from_numpy(frame)
                frame = torch.unsqueeze(frame, 0)
                frame = torch.nn.functional.interpolate(frame, (resized_height, resized_width), mode="nearest")
                frame = frame.squeeze().numpy()
                frame = torch.from_numpy(frame).float().to(self.device) / 255
                frame = im_normalization(frame)
                # inference model on specific frame
                if i == 0:
                    # preprocess input mask
                    input_mask = index_numpy_to_one_hot_torch(input_mask, num_objects)
                    input_mask = input_mask[1:]
                    input_mask = input_mask.to(self.device)
                    prediction = processor.step(frame, input_mask)
                else:
                    prediction = processor.step(frame)
                # postprocess prediction
                prediction = torch_prob_to_numpy_mask(prediction)
                prediction = torch.from_numpy(prediction)
                prediction = torch.unsqueeze(prediction, 0)
                prediction = torch.unsqueeze(prediction, 0)
                prediction = torch.nn.functional.interpolate(prediction, (original_height, original_width), mode="nearest")
                prediction = prediction.squeeze().numpy()
                # save predicted mask
                results.append(prediction)
        return results

What am I doing wrong? Why am I unable to track more than 35 frames even with a 24 GB GPU?

hkchengrex commented 1 year ago

I don't see anything wrong with your code at a glance, except for appending the final predictions to a list (but that should use CPU memory). A memory leak is likely happening somewhere, and it is probably not in InferenceCore (you can test this with eval.py). Can you narrow down the lines where the leak happens (e.g., with https://stackoverflow.com/questions/58216000/get-total-amount-of-free-gpu-memory-and-available-using-pytorch)?
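
For example, a minimal per-iteration check (just a sketch; log_gpu_memory is an illustrative helper built on torch.cuda.mem_get_info, not part of XMem) could look like this:

import torch

def log_gpu_memory(tag):
    # mem_get_info returns (free_bytes, total_bytes) for the current CUDA device
    free, total = torch.cuda.mem_get_info()
    print(f"{tag}: {free / 2**20:.0f} MiB free of {total / 2**20:.0f} MiB")

# e.g. call log_gpu_memory(f"after frame {i}") right after processor.step(...)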

MaxTeselkin commented 1 year ago

Hi @hkchengrex, thanks for the answer. I followed your advice and checked GPU memory usage on every iteration, and it keeps growing with every frame.

I tried adding the following lines of code to my predict function to free GPU memory from frames that had already been passed to the model, but it had no effect:

# remove frame and mask from GPU
input_mask.cpu()
frame.cpu()

Your paper says that XMem has low memory usage. Maybe you have some memory optimization algorithms (something like an LRU cache) that I am not using?

hkchengrex commented 1 year ago

If you are using InferenceCore with the config that you gave, you are already using the memory management algorithm. This is why I wonder if there are other lines of code that can potentially cause a memory leak.
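
(One note on the .cpu() calls in the previous comment: Tensor.cpu() returns a new CPU copy rather than moving the tensor in place, so calling it without using the result does not release anything on the GPU. Explicitly dropping the references would look more like the sketch below, although the caching allocator normally reuses that memory on its own, so this should not be needed here.)

# sketch: explicitly drop a GPU tensor and release cached memory (usually unnecessary)
del frame                  # remove the Python reference so the allocator can reuse the block
torch.cuda.empty_cache()   # hand cached, unused blocks back to the CUDA driver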

MaxTeselkin commented 1 year ago

@hkchengrex No, honestly, I sent you all the code that is used for running XMem. Right now I am resizing the smaller side to 480 and the larger side proportionally. I can try resizing to a lower resolution, say 320, but that still does not answer the question of why my XMem integration code turned out to be so memory-hungry.

MaxTeselkin commented 1 year ago

I also tried decreasing "min_mid_term_frames" and "max_mid_term_frames" to save memory, as proposed in your eval.py script, but it didn't give me significant results.
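
For reference, that tweak just lowers these two entries of the config dict from load_on_device (the values here are only illustrative, not the exact ones I used):

self.config["min_mid_term_frames"] = 3   # was 5 in the config above
self.config["max_mid_term_frames"] = 6   # was 10 in the config above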

hkchengrex commented 1 year ago

> @hkchengrex No, honestly, I sent you all the code that is used for running XMem. Right now I am resizing the smaller side to 480 and the larger side proportionally. I can try resizing to a lower resolution, say 320, but that still does not answer the question of why my XMem integration code turned out to be so memory-hungry.

I mean this is not "runnable". With 24 GB you should be able to run 480p videos comfortably.

> I also tried decreasing "min_mid_term_frames" and "max_mid_term_frames" to save memory, as proposed in your eval.py script, but it didn't give me significant results.

That should not be relevant.

One more thing -- how many objects are there?

MaxTeselkin commented 1 year ago

I have only 1 object. I added some lines of code to your tutorial Colab to check GPU memory usage on every iteration, and it is a lot lower than mine. But the code is almost the same. I am totally confused.

MaxTeselkin commented 1 year ago

Here is how my free GPU memory decreases with every single frame:

[image: screenshot of free GPU memory readings, dropping on every frame]

MaxTeselkin commented 1 year ago

While running your code in Colab shows significantly lower memory consumption:

[image: screenshot of GPU memory readings from the tutorial Colab]

MaxTeselkin commented 1 year ago

So every single frame takes almost 500 MB of GPU memory for me, and I have no idea why this is happening.

hkchengrex commented 1 year ago

Then I'm afraid that it is from other parts of the code or the environment. Unfortunately, I cannot help with either of those.

hkchengrex commented 1 year ago

Actually, could it be that autograd is on? In eval.py we turn off autograd with something like torch.set_grad_enabled(False).

MaxTeselkin commented 1 year ago

@hkchengrex Yes, you are right. The cause of the problem was that I call torch.set_grad_enabled(False) in the load_on_device function, but load_on_device and predict run in different threads, so disabling gradients in load_on_device has no effect on predict, which runs in a different thread. I was unaware that these functions run in different threads because I did not fully understand the legacy code I was using for my web app. Now it all makes sense. Thank you for the help and for the great neural network, good luck!
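
A sketch of the resulting pattern, for anyone running into the same thing: PyTorch's gradient mode is thread-local, so it has to be disabled in the thread that actually calls the model, e.g. by wrapping the inference loop in torch.no_grad() inside the function itself (run_inference, frames_gpu and first_mask_gpu below are illustrative names, not code from this thread):

import torch

def run_inference(processor, frames_gpu, first_mask_gpu):
    # torch.set_grad_enabled(False) called from another thread has no effect here,
    # because gradient mode is thread-local; disable it in the thread that runs the model
    results = []
    with torch.no_grad():
        for i, frame in enumerate(frames_gpu):
            if i == 0:
                prediction = processor.step(frame, first_mask_gpu)
            else:
                prediction = processor.step(frame)
            results.append(prediction.cpu())
    return results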