Closed nicheng0019 closed 4 months ago
我使用计图框架训练的代码大致是如下结构:
for train_idx in range(train_epoch): model.train() for (images, target) in train_loader: y = model(images) loss = loss_fn(y, target) optimizer.zero_grad() optimizer.backward(loss) optimizer.step()
训练的时候发现在train_idx 等于0时,里面一层的 for循环,显存的使用在随着batch不断增长,只到第一个epoch训练结束,从第二个epoch开始显存就不增长了,按照我的理解,应该是每个batch的训练,显存都不增长,请问是我的代码的问题还是计图框架就是这个机制?谢谢
已找到问题原因,是自己代码的问题,抱歉。
Describe the bug
我使用计图框架训练的代码大致是如下结构:
训练的时候发现在train_idx 等于0时,里面一层的 for循环,显存的使用在随着batch不断增长,只到第一个epoch训练结束,从第二个epoch开始显存就不增长了,按照我的理解,应该是每个batch的训练,显存都不增长,请问是我的代码的问题还是计图框架就是这个机制?谢谢