WongKinYiu / YOLO

An MIT rewrite of YOLOv9

Memory leak? Or steadily increasing memory usage #59

Open nathanWagenbach opened 1 month ago

nathanWagenbach commented 1 month ago

Describe the bug

I am running multi-GPU training and the amount of RAM used by the pt_main_thread processes keeps growing. I am training on the COCO17 dataset. The problem seems to scale with the number of workers I use, but it always occurs. Total memory usage starts off around 20 GB and then increases by about 10 GB per epoch.

The command I am using is: torchrun --nproc_per_node=3 yolo/lazy.py task=train device=[0,1,2] task.data.batch_size=16 name=v9-dev-mgpu_005 cpu_num=2 image_size=[640,640]

Has multi-gpu training been tested? It seems like it could be related to this issue https://github.com/pytorch/pytorch/issues/13246#issuecomment-1364587359
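
For context, the linked PyTorch issue is about per-worker memory growth when a Dataset keeps its metadata in plain Python lists: copy-on-write pages get touched by reference counting, so every dataloader worker gradually "copies" the whole list. A minimal sketch of the workaround usually suggested there, using a hypothetical dataset class rather than this repo's actual loader, is to pack that metadata into numpy arrays:

import numpy as np
from torch.utils.data import Dataset

class AnnotationDataset(Dataset):
    """Hypothetical dataset illustrating the pytorch#13246 workaround.

    Storing per-sample metadata in Python lists makes every dataloader
    worker slowly "copy" the list through copy-on-write reference counting;
    packing it into numpy arrays keeps it as one shared memory block.
    """

    def __init__(self, image_paths, labels):
        # Convert the Python lists to numpy arrays once, up front.
        self.image_paths = np.array(image_paths, dtype=np.bytes_)
        self.labels = np.array(labels, dtype=np.int64)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx].decode("utf-8")
        label = int(self.labels[idx])
        # Image loading and augmentation would happen here.
        return path, label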

System Info (please complete the following information):

Abdul-Mukit commented 1 month ago

I just faced the same problem using only the CPU. This needs to be addressed, otherwise I can't proceed further. I'm looking into this next. If anyone finds a clue, please ping here. @nathanWagenbach did you debug this issue? Any updates?

Abdul-Mukit commented 1 month ago

I tried something like https://github.com/pytorch/pytorch/issues/13246#issuecomment-436632186. I don't think it is related to the dataloader or multi-processing. My settings were device: cpu and cpu_num: 0, and it still happened.

I tried this:

import tracemalloc

import hydra

# NOTE: assuming these can be imported from this repo's yolo package;
# adjust the import paths to match your checkout.
from yolo import Config, ProgressLogger, create_dataloader, get_device

tracemalloc.start()

def loop_one_epoch(dataloader):
    # Iterate the dataloader exactly as the trainer would, without training.
    total_samples = 0
    for batch_size, images, targets, *_ in dataloader:
        total_samples += len(images)
    print(f"Epoch finished. Total samples: {total_samples}")

def loop_n_epochs(n_epochs, dataloader):
    for epoch_idx in range(n_epochs):
        print(f"Starting epoch {epoch_idx + 1} / {n_epochs}")
        loop_one_epoch(dataloader)

@hydra.main(config_path="config", config_name="config", version_base=None)
def main(cfg: Config):
    progress = ProgressLogger(cfg, exp_name=cfg.name)
    device, use_ddp = get_device(cfg.device)
    dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task, use_ddp)

    # Compare Python-level allocations before and after looping over the data.
    snap1 = tracemalloc.take_snapshot()
    loop_n_epochs(n_epochs=cfg.task.epoch, dataloader=dataloader)
    snap2 = tracemalloc.take_snapshot()

    top_stats = snap2.compare_to(snap1, 'lineno')
    print("Top 10 memory usage stats:")
    for stat in top_stats[:10]:
        print(stat)

if __name__ == "__main__":
    main()

Got:

Starting epoch 1 / 3
Epoch finished. Total samples: 1545
Starting epoch 2 / 3
Epoch finished. Total samples: 1545
Starting epoch 3 / 3
Epoch finished. Total samples: 1545
Top 10 memory usage stats:
/usr/lib/python3.9/tracemalloc.py:558: size=203 KiB (+203 KiB), count=4002 (+4001), average=52 B
.venv/lib/python3.9/site-packages/torch/serialization.py:1525: size=919 KiB (-115 KiB), count=10690 (-2063), average=88 B
<frozen importlib._bootstrap_external>:587: size=98.3 KiB (+51.0 KiB), count=1138 (+610), average=88 B
/usr/lib/python3.9/copy.py:142: size=0 B (-42.9 KiB), count=0 (-915)
/usr/lib/python3.9/copy.py:264: size=112 KiB (-41.4 KiB), count=1598 (-884), average=72 B
/usr/lib/python3.9/multiprocessing/reduction.py:40: size=12.9 KiB (+12.9 KiB), count=79 (+79), average=167 B
/usr/lib/python3.9/threading.py:892: size=10048 B (+10048 B), count=23 (+23), average=437 B
/usr/lib/python3.9/threading.py:912: size=8928 B (+8928 B), count=16 (+16), average=558 B
.venv/lib/python3.9/site-packages/torch/multiprocessing/reductions.py:485: size=8848 B (+8848 B), count=158 (+158), average=56 B
/usr/lib/python3.9/threading.py:954: size=8424 B (+8424 B), count=16 (+16), average=526 B

Which indicates no extra memory was added, AFAIK. The memory usage graph in the task manager bumped up only a little, unlike when the actual "leak" happens and memory usage gradually fills up to the maximum. (Note that tracemalloc only tracks allocations made through Python's memory allocator, so tensor storage allocated by PyTorch's own allocator would not show up here anyway.)

So far I am not sure if this is a leak at all. I tried measuring the memory consumption of train_one_epoch for a very small dataset, and it didn't seem to be leaking memory. Instead, when calling train_one_epoch on a very large dataset, the memory consumption reaches the maximum while that one epoch is still running. I could be wrong. It is not related to multi-GPU or cpu_num > 0; the issue happens even with no GPU and cpu_num=0. Every call to train_one_batch increases memory by a few GB. I am guessing the bug has something to do with the optimizer.

@henrytsui000 what do you think might be the source?
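
For anyone who wants to reproduce this measurement, here is a minimal sketch that logs the process resident set size (RSS) around each batch with psutil (an extra dependency). The trainer.train_one_batch(images, targets) call is a placeholder for the actual training step in yolo/tools/solver.py; adjust it to whatever signature your checkout uses.

import os
import psutil

process = psutil.Process(os.getpid())

def rss_gb() -> float:
    """Resident set size of this process, in GB."""
    return process.memory_info().rss / 1024 ** 3

def train_one_epoch_with_memory_log(trainer, dataloader):
    # Log RSS before and after each training step to see per-batch growth.
    for batch_idx, (batch_size, images, targets, *_) in enumerate(dataloader):
        before = rss_gb()
        loss = trainer.train_one_batch(images, targets)
        after = rss_gb()
        print(f"batch {batch_idx}: {before:.2f} GB -> {after:.2f} GB "
              f"(delta {after - before:+.2f} GB), loss={loss}")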

Abdul-Mukit commented 1 month ago

The problem is in train_one_batch. I measured the memory usage of each line while using cpu_num=0 and no GPU. Batch size is 16, which makes the image batch shape [16, 3, 640, 640].

Here are the two lines where this leak is happening: https://github.com/WongKinYiu/YOLO/blob/8228669808a626fc5f9c233fdb35550b5e041fae/yolo/tools/solver.py#L74 https://github.com/WongKinYiu/YOLO/blob/8228669808a626fc5f9c233fdb35550b5e041fae/yolo/tools/solver.py#L79

I present my measurements in the following table: RAM usage before the call to predicts = self.model(images), after that call, and after the call to self.scaler.scale(loss).backward().

Batch   Before model(images)   After model(images)   After scaler.scale(loss).backward()
B0      9.40 GB                 20.00 GB              15.90 GB
B1      15.90 GB                21.10 GB              18.10 GB
B2      18.10 GB                22.80 GB              20.02 GB

In the table above, for batch 0 (B0), memory consumption was 9.4 GB before the call to predicts = self.model(images). After executing that line, memory jumps to 20 GB, a 10 GB increase; I am not sure that is normal or optimal to begin with. I expected that after back-propagation (once train_one_batch finishes) the memory consumption would drop back down to 9.4 GB. Instead, after the call to self.scaler.scale(loss).backward(), it dropped only to 15.9 GB. That is a 6.5 GB memory leak for just 16 images. You can see that for batch 1 and batch 2 the leak is 2.2 GB and about 1.9 GB respectively. @WongKinYiu can you please give me a hint about what to look for next to solve this problem?
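
One general PyTorch cause of exactly this pattern, where memory is only partially released after backward(), is keeping references to loss or prediction tensors across batches (for example in a running-loss list or a progress logger), which keeps their whole autograd graphs alive. A minimal illustration of the anti-pattern and the usual fix; this is a general pattern, not necessarily what this repo does or what #83 changed:

import torch

def track_loss_leaky(loss_history: list, loss: torch.Tensor):
    # Anti-pattern: appending the live loss tensor keeps its whole autograd
    # graph (including saved activations) alive until the list is released.
    loss_history.append(loss)

def track_loss_safe(loss_history: list, loss: torch.Tensor):
    # Fix: detach to a plain Python float so only the scalar value survives
    # and the graph can be freed as soon as backward() has run.
    loss_history.append(loss.detach().item())

If per-batch losses are being averaged over the epoch or handed to a progress logger, checking whether they are detached first is a quick thing to try.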

Abdul-Mukit commented 1 month ago

I tried out #83 and it worked. Memory usage is still quite high, though: constant usage of around 30 GB for a batch shape of [16, 3, 640, 640].