Jittor / jittor

Jittor is a high-performance deep learning framework based on JIT compiling and meta-operators.
https://cg.cs.tsinghua.edu.cn/jittor/
Apache License 2.0

GPU memory usage grows during training #510

Closed nicheng0019 closed 4 months ago

nicheng0019 commented 4 months ago

Describe the bug

My training code with the Jittor framework has roughly the following structure:

```python
for train_idx in range(train_epoch):
    model.train()
    for (images, target) in train_loader:
        y = model(images)          # forward pass
        loss = loss_fn(y, target)
        optimizer.zero_grad()      # clear accumulated gradients
        optimizer.backward(loss)   # backward pass (Jittor-style API)
        optimizer.step()           # update parameters
```

During training I noticed that while train_idx is 0, GPU memory usage keeps growing batch by batch inside the inner for loop, and only stops once the first epoch finishes; from the second epoch onward, memory no longer grows. My understanding is that memory usage should stay flat from one batch to the next. Is this a problem in my code, or is this simply how the Jittor framework behaves? Thanks.
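The thread never reveals the actual cause, but one common way a loop like the one above accumulates memory during the first epoch is keeping a Python reference to each batch's loss tensor (for example, appending it to a logging list), which pins the tensor and everything it references until the reference is dropped. The sketch below is a framework-free illustration of that pattern: the helper name `train_loop` and the 8 MB `bytearray` (standing in for a live framework Var) are made up for this example.

```python
import sys

def train_loop(detach_loss, n_batches=3):
    """Simulate per-batch loss logging.

    Each 8 MB bytearray stands in for a framework Var that keeps its
    computation graph / device buffers alive while referenced.
    """
    losses = []
    for _ in range(n_batches):
        loss = bytearray(8 * 1024 * 1024)  # stand-in for a live loss Var
        if detach_loss:
            losses.append(float(loss[0]))  # keep only a plain Python scalar
        else:
            losses.append(loss)            # retains the whole buffer each batch
    return losses

retained = sum(sys.getsizeof(x) for x in train_loop(detach_loss=False))
detached = sum(sys.getsizeof(x) for x in train_loop(detach_loss=True))
print(retained, detached)  # retained holds ~24 MB; detached only a few dozen bytes
```

In a real Jittor loop the equivalent fix is to log a detached Python number (e.g. `loss.item()`) instead of the Var itself, so the graph behind each batch's loss can be freed immediately.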

nicheng0019 commented 4 months ago

I've found the cause: it was a bug in my own code. Sorry for the noise.