wangtianrui / DCCRN

implementation of "DCCRN-Deep Complex Convolution Recurrent Network for Phase-Aware Speech Enhancement" by pytorch
49 stars 16 forks source link

模型训练过程中ComplexBatchNorm函数导致内存爆炸 #9

Open RuqiaoLiu opened 2 years ago

RuqiaoLiu commented 2 years ago

您好:

当我在模型中使用ComplexBatchNorm时,训练一段时间内存就会爆炸。但是我把ComplexBatchNorm换成torch自带的BatchNorm就不会出现该问题。请问您在训练DCCRN的时候没有出现这个情况吗?

谢谢; 祝好!

wangtianrui commented 2 years ago

额。我这儿没有用ComplexBatchNorm。不过从你的现象来看,您是不是用的最开始我复现的代码。那个代码有问题,你用现在这个仓库的就行了。现在的模型是官方(西工大)公开的代码。

RuqiaoLiu commented 2 years ago

嗯,我就是使用的您现在这个仓库。我发现把ComplexBatchNorm中的track_running_stats设置为False就能避免这个问题。但是track_running_stats在torch中的BatcNorm都是默认为True的。

wangtianrui commented 2 years ago

哦 抱歉。我看到了。我理解错误了。这个问题我没遇到过。我以为你看的是我以前写的一个类。刚刚翻了一下,才发现这里面也有。原来他写的也有这个问题啊。我没有试过你的这个方案。我以前那个方案是因为梯度的问题,每次梯度都被保留下来了,需要清除一次,希望能够帮到你

RuqiaoLiu commented 2 years ago

噢,请问ComplexBatchNorm的实现会影响梯度的累积吗?我理解的是梯度计算是体现在train.py里面。

RuqiaoLiu commented 2 years ago

我发现speechbrain中实现的CBatchNorm里,它的forward函数中会对track_running_stats=True时产生的Buffer参数做一个.detach()操作。是因为这个导致的吗?

wangtianrui commented 2 years ago

有可能是的,需要你尝试一下。

SolituderAlex commented 2 years ago

请问数据集在哪里下载呀

596227421 commented 2 years ago

哦 抱歉。我看到了。我理解错误了。这个问题我没遇到过。我以为你看的是我以前写的一个类。刚刚翻了一下,才发现这里面也有。原来他写的也有这个问题啊。我没有试过你的这个方案。我以前那个方案是因为梯度的问题,每次梯度都被保留下来了,需要清除一次,希望能够帮到你

请问作者,以前那个代码需要如何修改才能够正常运行,我现在跑到半个epoch就卡住不动了,查看gpu内存发现并没有超出,代码中也有optimizer.zero_grad()清除梯度

tongyu1 commented 1 year ago

这个内存问题是怎么定位到的,我发现我也遇到类似的内存问题,但是定位不到是哪里导致的内存爆炸