Oneflow-Inc / one-yolov5

A more efficient yolov5 with oneflow backend 🎉🎉🎉
https://start.oneflow.org/oneflow-yolo-doc
GNU General Public License v3.0

Memory leak when using deepcopy(de_parallel(m)) in DDP mode #101

Closed ccssu closed 1 year ago

ccssu commented 1 year ago

Problem

Description: in DDP mode, deepcopy(de_parallel(m)) leaks memory. When training one-yolo on multiple GPUs, GPU memory usage on the first card grows slowly but steadily while the other cards stay stable. Debugging with the yolov5n model pointed to one cause (see https://github.com/Oneflow-Inc/OneTeam/issues/1856), but reproducing with the larger yolov5l and yolov5x models led to a different conclusion. The final diagnosis is that deepcopy(de_parallel(m)) in DDP mode leaks memory; a minimal reproduction is below.
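As background on why a per-iteration deepcopy of a live model is expensive: Python's deepcopy walks the entire reachable object graph, so gradient buffers and any other state hanging off the parameters are duplicated on every checkpoint. A pure-Python sketch of that behavior (FakeParam and FakeModel are hypothetical stand-ins, not oneflow types):

```python
from copy import deepcopy

class FakeParam:
    """Hypothetical stand-in for a tensor parameter; `grad` mimics autograd state."""
    def __init__(self, data):
        self.data = list(data)
        self.grad = None

class FakeModel:
    """Hypothetical stand-in for the repro's Model below."""
    def __init__(self):
        self.w = FakeParam([0.0, 0.0])

m = FakeModel()
m.w.grad = [1.0] * 4  # pretend a backward pass populated gradients

clone = deepcopy(m)
# deepcopy walks the full object graph: the clone gets its own copies of
# the gradient buffers too, so every per-iteration checkpoint allocates
# all of that state again.
print(clone.w.grad is m.w.grad)  # False: a fresh buffer was allocated
```

In a real framework the copied state also includes device allocations, which is where a backend bug can turn this duplication into a leak.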

Minimal reproducible example

Launch command: python3 -m oneflow.distributed.launch --nproc_per_node 2 ./ddp_train.py

OneFlow version: the oneflow master branch

ddp_train.py
import oneflow as flow
import oneflow.nn as nn
from oneflow.nn.parallel import DistributedDataParallel as ddp
from copy import deepcopy

train_x = [
    flow.tensor([[1, 2], [2, 3]], dtype=flow.float32),
    flow.tensor([[4, 6], [3, 1]], dtype=flow.float32),
]
train_y = [
    flow.tensor([[8], [13]], dtype=flow.float32),
    flow.tensor([[26], [9]], dtype=flow.float32),
]

class Model(flow.nn.Module):
    def __init__(self):
        super().__init__()
        self.lr = 0.01
        self.iter_count = 500*1024
        self.w = flow.nn.Parameter(flow.tensor([[0], [0]], dtype=flow.float32))

    def forward(self, x):
        x = flow.matmul(x, self.w)
        return x

def is_parallel(model):
    # Returns True if model is of type DDP
    return type(model) in (nn.parallel.DistributedDataParallel,)

def de_parallel(model):
    # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
    return model.module if is_parallel(model) else model

m = Model().to("cuda")
m = ddp(m)
loss = flow.nn.MSELoss(reduction="sum")
optimizer = flow.optim.SGD(m.parameters(), m.lr)

for i in range(0, m.iter_count):
    rank = flow.env.get_rank()
    x = train_x[rank].to("cuda")
    y = train_y[rank].to("cuda")

    y_pred = m(x)
    l = loss(y_pred, y)
    if (i + 1) % 500 == 0:
        print(f"{i+1}/{m.iter_count} loss:{l}")
    if rank in {-1, 0}:
        ckpt = {
            "model": deepcopy(de_parallel(m))
        }
        del ckpt
    optimizer.zero_grad()
    l.backward()
    optimizer.step()

Notes:

  1. Launch command for the reproduction: python3 -m oneflow.distributed.launch --nproc_per_node 2 ./ddp_train.py
  2. OneFlow version: the latest oneflow master (here flow.__version__ = 0.8.1.dev20230102+cu117)
  3. Commenting out the line "model": deepcopy(de_parallel(m)) makes the memory leak disappear.
  4. The corresponding code in the one-yolov5 project: https://github.com/Oneflow-Inc/one-yolov5/blob/9d908a104e00eef5cc9927be06f4d3f1d31ca517/train.py#L416-L435
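Since note 3 shows the leak disappears when the module is not deep-copied, one general way to sidestep it is to snapshot only the numeric state rather than the live module. This is a sketch of that pattern, not the project's actual fix; TinyModule and checkpoint_state are hypothetical stand-ins:

```python
class TinyModule:
    """Hypothetical stand-in for flow.nn.Module, for illustration only."""
    def __init__(self):
        self._params = {"w": [0.0, 0.0]}

    def state_dict(self):
        return self._params

def checkpoint_state(model):
    # Copy only the numeric state; no live module object (and hence no
    # backend handles) is duplicated the way deepcopy(module) would be.
    return {k: list(v) for k, v in model.state_dict().items()}

m = TinyModule()
ckpt = {"model": checkpoint_state(m)}
m.state_dict()["w"][0] = 1.0  # training keeps mutating the live weights
print(ckpt["model"]["w"][0])  # 0.0: the snapshot is decoupled
```

With oneflow, the analogous snapshot would copy the tensors out of de_parallel(m).state_dict() (e.g. detaching and cloning each value) instead of deep-copying the wrapped module.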

Version info

BBuf commented 1 year ago

Has this issue been resolved? @ccssu

ccssu commented 1 year ago

> Has this issue been resolved? @ccssu

Resolved. The memory leak is fixed: tested with oneflow 0.9.0, and skipping the deepcopy has no effect on the model. Below is a screenshot from a training run of about 20h (wandb log: ).
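The fix was confirmed by watching memory stay flat over a long run. A host-side analogue of that check can be written with the stdlib tracemalloc module (a sketch only; the original issue concerned GPU memory, which would need nvidia-smi or similar tooling instead, and FakeModel is a hypothetical stand-in):

```python
import tracemalloc
from copy import deepcopy

class FakeModel:
    """Hypothetical stand-in for the repro's Model; holds a parameter-sized list."""
    def __init__(self):
        self.w = [0.0] * 1024

tracemalloc.start()
m = FakeModel()

# Warm-up copy so one-time allocations do not count against the loop.
warmup = deepcopy(m)
del warmup
baseline, _ = tracemalloc.get_traced_memory()

for _ in range(1000):
    ckpt = {"model": deepcopy(m)}
    del ckpt  # dropped every iteration, as in the repro script

current, _ = tracemalloc.get_traced_memory()
tracemalloc.stop()
print(current - baseline < 64 * 1024)  # True: no unbounded growth
```

If the copy were leaking, the traced footprint would grow with the iteration count instead of staying near the baseline.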