jerpelhan / DAVE

MIT License
34 stars · 3 forks

Training memory leak #11

Open x55admin opened 2 months ago

x55admin commented 2 months ago

Thank you for your code, the results are great. However, I ran into problems training on the Windows 11 platform, so I modified the training script. Now the GPU memory occupied by the script grows after every iteration, until it exhausts the card's memory and the program crashes with an OOM error. Is the problem in my code or in the program itself? Thank you for your answer.

Platform: Windows 11
GPU: NVIDIA 4090
Python: 3.10
torch: 2.3.1 (CUDA 12 build)
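For reference, this kind of per-iteration growth can be confirmed with a small logging helper such as the sketch below (it uses only the standard torch.cuda counters; the name log_gpu_memory and the log_every parameter are illustrative, not part of the repo):

import torch

def log_gpu_memory(step, log_every=20):
    # memory_allocated: bytes currently held by live tensors;
    # max_memory_allocated: peak allocation seen so far in this process
    if step % log_every == 0:
        alloc = torch.cuda.memory_allocated() / 2 ** 20
        peak = torch.cuda.max_memory_allocated() / 2 ** 20
        print(f'step {step}: allocated {alloc:.1f} MiB, peak {peak:.1f} MiB')

Calling log_gpu_memory(index) right after optimizer.step() shows whether allocated memory settles to a steady baseline or keeps climbing across iterations.

The modified script: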

import argparse
import gc
import math
import os
from time import perf_counter

import numpy as np
import skimage.feature  # explicit submodule import; older scikit-image versions do not auto-load it
import torch
from torch import distributed as dist
from torch import nn
from torch.utils.data import DataLoader
from torchvision.ops import box_iou

from models.box_prediction import BoxList, boxlist_nms
from models.dave import build_model
from utils.data import FSC147WithDensityMapDOWNSIZE
from utils.losses import Criterion, Detection_criterion
from utils.train_arg_parser import get_argparser

DATASETS = {
    'fsc147': FSC147WithDensityMapDOWNSIZE,
}

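# generate_bbox: turns a predicted density map plus per-pixel TLRB offsets into
# score-ranked, NMS-filtered bounding boxes (one BoxList per image in the batch)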
def generate_bbox(density_map, tlrb):
    bboxes = []
    for i in range(density_map.shape[0]):
        density = density_map[i][0].cpu().numpy()
        dmap = density.copy()  # working copy to threshold; density stays intact for scoring

        mask = dmap < np.max(dmap) / 3
        dmap[mask] = 0
        a = skimage.feature.peak_local_max(dmap, exclude_border=0)

        boxes = []
        scores = []
        b, l, r, t = tlrb[i]

        for x11, y11 in a:
            box = [y11 - b[x11][y11].item(), x11 - l[x11][y11].item(), y11 + r[x11][y11].item(),
                   x11 + t[x11][y11].item()]
            boxes.append(box)
            scores.append(
                1 - math.fabs(density[int(box[1]): int(box[3]), int(box[0]):int(box[2])].sum() - 1))

        boxlist = BoxList(boxes, (density_map.shape[3], density_map.shape[2]))
        boxlist.fields['scores'] = torch.tensor(scores)
        boxlist = boxlist.clip()
        boxlist = boxlist_nms(boxlist, boxlist.fields['scores'], 0.55)

        bboxes.append(boxlist)
    return bboxes

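# reduce_dict: multi-GPU helper that all-reduces a dict of loss tensors across
# ranks; never called in this script (its call sites below are commented out)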
def reduce_dict(input_dict):
    with torch.no_grad():
        names = []
        values = []
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.all_reduce(values)
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict

def train(args):
    if args.skip_train:
        print("SKIPPING TRAIN")
        return
    rank = 0
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    assert args.backbone in ['resnet18', 'resnet50', 'resnet101']
    assert args.reduction in [4, 8, 16]

    model = build_model(args).to(device)

    # model.load_state_dict(
    #     torch.load(os.path.join(args.model_path, args.model_name + '.pth'))['model'], strict=False
    # )

    backbone_params = dict()
    non_backbone_params = dict()
    fcos_params = dict()
    feat_comp = dict()
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if 'backbone' in n:
            backbone_params[n] = p
        elif 'box_predictor' in n:
            fcos_params[n] = p
        elif 'feat_comp' in n:
            feat_comp[n] = p
        else:
            non_backbone_params[n] = p

    optimizer = torch.optim.AdamW(
        [
            {'params': fcos_params.values(), 'lr': args.lr},
        ],
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop, gamma=0.25)
    if args.resume_training:
        checkpoint = torch.load(os.path.join(args.model_path, f'{args.model_name}.pth'))
        print(model.state_dict().keys())
        model.load_state_dict({k.replace('module.', ''): v for k, v in checkpoint['model'].items()})
        start_epoch = checkpoint['epoch']
        best_mAP = checkpoint['best_val_ae']
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
    else:
        start_epoch = 0
        best_mAP = 0

    criterion = Criterion(args)
    aux_criterion = Criterion(args, aux=True)
    det_criterion = Detection_criterion(
        [[-1, args.fcos_pred_size], [64, 128], [128, 256], [256, 512], [512, 100000000]],  # config.sizes,
        'giou',  # config.iou_loss_type,
        True,  # config.center_sample,
        [1],  # config.fpn_strides,
        5,  # config.pos_radius,
    )
    train = DATASETS[args.dataset](
        args.data_path,
        args.image_size,
        split='train',
        num_objects=args.num_objects,
        tiling_p=args.tiling_p,
        zero_shot=args.zero_shot or args.orig_dmaps,
        skip_cars=args.skip_cars,
    )
    val = DATASETS[args.dataset](
        args.data_path,
        args.image_size,
        split='val',
        num_objects=args.num_objects,
        tiling_p=args.tiling_p,
        zero_shot=args.zero_shot or args.orig_dmaps,
    )
    train_loader = DataLoader(
        train,
        batch_size=args.batch_size,
        drop_last=True,
        num_workers=args.num_workers,
    )
    val_loader = DataLoader(
        val,
        batch_size=args.batch_size,
        drop_last=False,
        num_workers=args.num_workers,
    )
    print("NUM STEPS", len(train_loader) * args.epochs)
    print(rank, len(train_loader))
    for epoch in range(start_epoch + 1, args.epochs + 1):
        print('epoch:', epoch)
        start = perf_counter()
        # train_losses = {k: torch.tensor(0.0).to(device) for k in criterion.losses.keys()}
        # val_losses = {k: torch.tensor(0.0).to(device) for k in criterion.losses.keys()}
        # aux_train_losses = {k: torch.tensor(0.0).to(device) for k in aux_criterion.losses.keys()}
        # aux_val_losses = {k: torch.tensor(0.0).to(device) for k in aux_criterion.losses.keys()}
        train_ae = torch.tensor(0.0).to(device)
        val_ae = torch.tensor(0.0).to(device)
        mAP = torch.tensor(0.0).to(device)

        model.train()

        # for index, (img, bboxes, density_map, ids, scale_x, scale_y, _) in enumerate(train_loader):
        for index, (img, bboxes, density_map, _, _, _, _) in enumerate(train_loader):
            img = img.to(device)
            bboxes = bboxes.to(device)
            density_map = density_map.to(device)
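            # targets: GT boxes rescaled from image coordinates to the FCOS
            # prediction grid; each target box is assigned class label 1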
            targets = BoxList(bboxes, (args.image_size, args.image_size), mode='xyxy').to(device).resize(
                (args.fcos_pred_size, args.fcos_pred_size))
            targets.fields['labels'] = [1 for __ in range(args.batch_size * 2)]
            optimizer.zero_grad()
            outR, aux_R, tblr, location = model(img, bboxes)

            if args.normalized_l2:
                with torch.no_grad():
                    num_objects = density_map.sum()
            else:
                num_objects = None

            main_losses = criterion(outR, density_map, bboxes, num_objects)
            aux_losses = [
                aux_criterion(aux, density_map, bboxes, num_objects) for aux in aux_R
            ]
            det_loss = det_criterion(location, tblr, targets)
            del targets
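            # NOTE: the main and aux density losses below are multiplied by 0,
            # so only det_loss actually contributes gradients in this run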
            loss = (
                    sum([ml for ml in main_losses.values()]) * 0 +
                    sum([al for alls in aux_losses for al in alls.values()]) * 0 +
                    det_loss  # + l
            )
            loss.backward()
            if args.max_grad_norm > 0:
                nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            # train_losses = {
            #     k: train_losses[k] + main_losses[k] * img.size(0) for k in train_losses.keys()
            # }
            # aux_train_losses = {
            #     k: aux_train_losses[k] + sum([a[k] for a in aux_losses]) * img.size(0)
            #     for k in aux_train_losses.keys()
            # }
            # detach before accumulating: without .detach(), every iteration's
            # autograd graph stays reachable through train_ae, so GPU memory
            # grows step after step (a likely cause of the OOM described above)
            train_ae += torch.abs(
                density_map.flatten(1).sum(dim=1) - outR.flatten(1).sum(dim=1)
            ).sum().detach()
            if index % 20 == 0:
                print(f'step:{index}/{len(train_loader)}  loss:{loss.item():.4f}')

        model.eval()
        with torch.no_grad():
            for index, (img, bboxes, density_map, ids, scale_x, scale_y, _) in enumerate(val_loader):
                gt_bboxes, _ = val.get_gt_bboxes(ids)
                img = img.to(device)
                bboxes = bboxes.to(device)
                density_map = density_map.to(device)

                optimizer.zero_grad()

                outR, aux_R, tblr, location = model(img, bboxes)

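                # decode the predicted density map and TLRB offsets into boxes,
                # then rescale them back to GT coordinates for the mAP-style score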
                boxes_pred = generate_bbox(outR, tblr)

                for iii in range(len(gt_bboxes)):
                    boxes_pred[iii].box = boxes_pred[iii].box / torch.tensor(
                        [scale_y[iii], scale_x[iii], scale_y[iii], scale_x[iii]])
                    # mean best-IoU per GT box; normalize by the number of GT boxes
                    mAP += box_iou(gt_bboxes[iii], boxes_pred[iii].box).max(dim=1)[0].sum() / gt_bboxes[iii].shape[0]

                if args.normalized_l2:
                    with torch.no_grad():
                        num_objects = density_map.sum()
                else:
                    num_objects = None
                main_losses = criterion(outR, density_map, bboxes, num_objects)
                # aux_losses = [
                #     aux_criterion(aux, density_map, bboxes, num_objects) for aux in aux_R
                # ]
                # val_losses = {
                #     k: val_losses[k] + main_losses[k] * img.size(0) for k in val_losses.keys()
                # }
                # aux_val_losses = {
                #     k: aux_val_losses[k] + sum([a[k] for a in aux_losses]) * img.size(0)
                #     for k in aux_val_losses.keys()
                # }
                val_ae += torch.abs(
                    density_map.flatten(1).sum(dim=1) - outR.flatten(1).sum(dim=1)
                ).sum()
                # main_losses is a dict of loss tensors, so sum its values before printing
                print(f'step:{index}/{len(val_loader)}  loss:{sum(main_losses.values()).item():.4f}')

        # train_losses = reduce_dict(train_losses)
        # val_losses = reduce_dict(val_losses)
        # aux_train_losses = reduce_dict(aux_train_losses)
        # aux_val_losses = reduce_dict(aux_val_losses)

        scheduler.step()

        if rank == 0:
            end = perf_counter()
            best_epoch = False

            if mAP > best_mAP:
                best_mAP = mAP
                checkpoint = {
                    'epoch': epoch,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'best_val_ae': val_ae.item() / len(val)
                }
                torch.save(
                    checkpoint,
                    os.path.join(args.model_path, f'{args.det_model_name}.pth')
                )
                best_epoch = True

            print("Epoch", epoch)
            # print({k: v.item() / len(train) for k, v in train_losses.items()})
            # print({k: v.item() / len(val) for k, v in val_losses.items()})
            # print({k: v.item() / len(train) for k, v in aux_train_losses.items()})
            # print({k: v.item() / len(val) for k, v in aux_val_losses.items()})
            print(
                train_ae.item() / len(train),
                val_ae.item() / len(val),
                end - start,
                'best' if best_epoch else '',
            )
            print("det_sc:", mAP / len(val))
            print("********")

    if args.skip_test:
        # destroy_process_group() raises if no process group was ever initialized
        if dist.is_initialized():
            dist.destroy_process_group()

if __name__ == '__main__':
    parser = argparse.ArgumentParser('DAVE', parents=[get_argparser()])
    args = parser.parse_args()
    print(args)
    train(args)
jerpelhan commented 1 month ago

Sorry for the late reply. I had no issues training on Linux, and unfortunately, I cannot test it on Windows.