hkchengrex / Cutie

[CVPR 2024 Highlight] Putting the Object Back Into Video Object Segmentation
https://hkchengrex.com/Cutie/
MIT License

point supervision #12

Closed. ret-1 closed this issue 12 months ago

ret-1 commented 12 months ago

Thanks for your great work.

Since Cutie adopts point supervision for training to reduce memory requirements, I replaced my loss function with yours, but it didn't reduce the video memory. I also replaced your loss function with the original XMem one, modified as shown below, and again found the memory to be almost identical. Do you have any idea what might cause this?


import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import defaultdict
from cutie.utils.tensor_utils import cls_to_one_hot

def dice_loss(mask: torch.Tensor, soft_gt: torch.Tensor) -> torch.Tensor:
    # mask: T*C*H*W
    # soft_gt: T*C*H*W
    # ignores the background
    mask = mask[:, 1:].flatten(start_dim=2)
    gt = soft_gt[:, 1:].float().flatten(start_dim=2)
    numerator = 2 * (mask * gt).sum(-1)
    denominator = mask.sum(-1) + gt.sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum(0).mean()

# https://stackoverflow.com/questions/63735255/how-do-i-compute-bootstrapped-cross-entropy-loss-in-pytorch
class BootstrappedCE(nn.Module):
    def __init__(self, start_warm, end_warm, top_p=0.15):
        super().__init__()

        self.start_warm = start_warm
        self.end_warm = end_warm
        self.top_p = top_p

    def forward(self, input, target, it):
        # before start_warm: plain cross-entropy over all pixels
        if it < self.start_warm:
            return F.cross_entropy(input, target), 1.0

        # afterwards, keep only the hardest (highest-loss) fraction of pixels
        raw_loss = F.cross_entropy(input, target, reduction='none').view(-1)
        num_pixels = raw_loss.numel()

        if it > self.end_warm:
            this_p = self.top_p
        else:
            # linearly anneal the kept fraction from 1.0 down to top_p
            this_p = self.top_p + (1 - self.top_p) * ((self.end_warm - it) / (self.end_warm - self.start_warm))
        loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False)
        return loss.mean(), this_p

class LossComputerXMem:
    def __init__(self, config=None):
        super().__init__()
        # self.config = config
        self.bce = BootstrappedCE(20000, 70000)

    def compute(self, data, num_objects, it):
        losses = defaultdict(float)

        b, num_frames = data['rgb'].shape[:2]
        t_range = range(1, num_frames)

        for bi in range(b):
            logits = torch.stack([data[f'logits_{ti}'][bi, :num_objects[bi] + 1] for ti in t_range], dim=0)
            cls_gt = data['cls_gt'][bi, 1:]  # remove gt for the first frame
            soft_gt = cls_to_one_hot(cls_gt, num_objects[bi])

            loss, _ = self.bce(logits, soft_gt, it)
            losses['ce_loss'] += (loss / b)

            loss = dice_loss(logits.softmax(dim=1), soft_gt)
            losses['dice_loss'] += (loss / b)

            aux = [data[f'aux_{ti}'] for ti in t_range]
            if 'sensory_logits' in aux[0]:
                sensory_log = torch.stack(
                    [a['sensory_logits'][bi, :num_objects[bi] + 1] for a in aux], dim=0)

                loss, _ = self.bce(sensory_log, F.interpolate(soft_gt, scale_factor=1/16), it)
                losses['aux_sensory_ce'] += (loss / b)

                loss = dice_loss(sensory_log.softmax(dim=1), F.interpolate(soft_gt, scale_factor=1/16))
                losses['aux_sensory_dice'] += (loss / b)

            if 'q_logits' in aux[0]:
                num_levels = aux[0]['q_logits'].shape[2]

                for l in range(num_levels):
                    query_log = torch.stack(
                        [a['q_logits'][bi, :num_objects[bi] + 1, l] for a in aux], dim=0)

                    loss, _ = self.bce(query_log, F.interpolate(soft_gt, scale_factor=1 / 16), it)
                    losses[f'aux_query_ce_l{l}'] += (loss / b)

                    loss = dice_loss(query_log.softmax(dim=1), F.interpolate(soft_gt, scale_factor=1/16))
                    losses[f'aux_query_dice_l{l}'] += (loss / b)

        losses['total_loss'] = sum(losses.values())

        return losses
hkchengrex commented 12 months ago

At that time (early stage of development, with a very different model/architecture), we got OOM with standard supervision but not with point supervision. We took this from Mask2Former.
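
For context: point supervision, as used in Mask2Former and PointRend, evaluates the mask loss on a small set of sampled points instead of on every pixel. Below is a minimal sketch of the uniform-sampling core of the idea only; Mask2Former additionally oversamples uncertain points, and the function names here are illustrative rather than Cutie's actual implementation.

import torch
import torch.nn.functional as F

def point_sample(x: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
    # x: B*C*H*W, coords: B*P*2 with (x, y) in [0, 1]
    # grid_sample wants a 4D grid in [-1, 1], so add a dummy dimension
    out = F.grid_sample(x, 2.0 * coords.unsqueeze(2) - 1.0, align_corners=False)
    return out.squeeze(3)  # B*C*P

def point_supervised_ce(logits: torch.Tensor, soft_gt: torch.Tensor,
                        num_points: int = 12544) -> torch.Tensor:
    # sample the same random locations from predictions and targets,
    # then compute cross-entropy on those points only
    coords = torch.rand(logits.shape[0], num_points, 2, device=logits.device)
    return F.cross_entropy(point_sample(logits, coords),
                           point_sample(soft_gt, coords))

(12544 = 112 * 112 is the default point count in Mask2Former.)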

ret-1 commented 12 months ago

I re-tried main_training with your config (batch_size changed from 16 to 8) using four 3090s, but the two loss functions seem to use nearly the same amount of memory.

Point supervision is the first run; XMem is the second.

[screenshot: GPU memory usage for the two runs]

If it's convenient, could you please try replacing the loss function with the original XMem one and see whether it reduces memory on your machine?

Thanks a lot!

hkchengrex commented 12 months ago

It did save memory during initial development (we were using bipartite matching with cost matrices). You might be right that it does not save memory with the latest model. I will have to check this and update the paper if necessary. I cannot do this right now since I am quite busy with other things, but thank you for letting me know.
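
To make that concrete: for bipartite matching, the cost matrix needs a dense mask loss for every (prediction, ground truth) pair, so the intermediates scale with Q*G*H*W. A rough illustration with made-up sizes (not Cutie's actual ones):

import torch

# hypothetical sizes: Q predicted masks, G ground-truth masks, H*W pixels
Q, G, H, W = 100, 10, 480, 854
pred = torch.rand(Q, H * W)                    # flattened mask probabilities
gt = torch.randint(0, 2, (G, H * W)).float()   # flattened binary targets

# dense pairwise cost: materializes a Q*G*(H*W) intermediate,
# ~100 * 10 * 410k floats = ~1.6 GB for this one term alone
dense = (pred[:, None, :] * gt[None, :, :]).sum(-1)

# the same cost on P sampled points: Q*G*P, roughly 33x smaller here
P = 12544
idx = torch.randint(0, H * W, (P,))
sampled = (pred[:, None, idx] * gt[None, :, idx]).sum(-1)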

hkchengrex commented 12 months ago

I briefly checked: with the cache cleared (batch size 1), the full version uses 264~350 MB during loss computation and the point supervision version uses 41~66 MB.

hkchengrex commented 12 months ago

So with the cache, the extra memory cost is likely masked by other temporary allocations in the network.
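
If anyone wants to reproduce the comparison: nvidia-smi reports the allocator's reserved (cached) memory, not live allocations, so measuring the loss itself needs PyTorch's allocator statistics directly. A rough sketch, where loss_computer, data, num_objects, and it stand in for whatever the training loop provides:

import torch

torch.cuda.empty_cache()               # return cached blocks so the peak is meaningful
torch.cuda.reset_peak_memory_stats()

before = torch.cuda.memory_allocated()
losses = loss_computer.compute(data, num_objects, it)   # hypothetical call
peak = torch.cuda.max_memory_allocated()

print(f'loss computation peak: {(peak - before) / 2**20:.1f} MB')
print(f'allocated: {torch.cuda.memory_allocated() / 2**20:.1f} MB, '
      f'reserved (cache): {torch.cuda.memory_reserved() / 2**20:.1f} MB')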

ret-1 commented 12 months ago

Thank you so much for your cooperation!

So, based on your experiments, is there not much memory difference between the two loss functions under normal training conditions (with the cache)? I had been hoping point supervision would be very effective at reducing memory; when I saw the paper's claim of "using only one-third of the memory during training", I rushed to swap in the loss.

I'm guessing that point supervision solved the OOM in your initial version because of the bipartite matching.