santi-pdp / segan

Speech Enhancement Generative Adversarial Network in TensorFlow
MIT License
816 stars 281 forks

OOM with small batch size on GPU #65

Open Mega4alik opened 5 years ago

Mega4alik commented 5 years ago

Hi, I'm using @lordet01's implementation: https://github.com/lordet01/segan

In model.py, the following part of the code leads to an OOM error even with a small batch_size. Has anyone managed to solve this problem?

            while not coord.should_stop():
                start = 0 #timeit.default_timer()
                if counter % config.save_freq == 0:
                    for d_iter in range(self.disc_updates):
                        _d_opt, _d_sum, \
                        d_fk_loss, \
                        d_rl_loss = self.sess.run([d_opt, self.d_sum,
                                                   self.d_fk_losses[0],
                                                   #self.d_nfk_losses[0],
                                                   self.d_rl_losses[0]])
                        if self.d_clip_weights:
                            self.sess.run(self.d_clip)
                        #d_nfk_loss, \

                    # now G iterations
                    _g_opt, _g_sum, \
                    g_adv_loss, \
                    g_l1_loss = self.sess.run([g_opt, self.g_sum,
                                               self.g_adv_losses[0],
                                               self.g_l1_losses[0]])
                else:
                    for d_iter in range(self.disc_updates):
                        _d_opt, \
                        d_fk_loss, \
                        d_rl_loss = self.sess.run([d_opt,
                                                   self.d_fk_losses[0],
                                                   #self.d_nfk_losses[0],
                                                   self.d_rl_losses[0]])
                        #d_nfk_loss, \
                        if self.d_clip_weights:
                            self.sess.run(self.d_clip)

                    _g_opt, \
                    g_adv_loss, \
                    g_l1_loss = self.sess.run([g_opt, self.g_adv_losses[0],
                                               self.g_l1_losses[0]])

LexiYIN commented 3 years ago

Same problem. Did you solve it?
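
The thread does not record a fix. As a reading aid only: the two branches of the loop in the question differ just in whether the summary ops (self.d_sum, self.g_sum) are fetched, so one training step can be written once with a conditional fetch list. Below is a minimal sketch of that structure. FakeSession, train_step, and the string fetch names are hypothetical placeholders, not part of segan or TensorFlow, and this restructuring is not claimed to change memory use or to resolve the OOM.

```python
class FakeSession:
    """Hypothetical stand-in for a TF session: records each fetch list
    and returns one dummy value per fetch."""

    def __init__(self):
        self.calls = []

    def run(self, fetches):
        if not isinstance(fetches, list):
            fetches = [fetches]
        self.calls.append(list(fetches))
        return [0.0 for _ in fetches]


def train_step(sess, counter, save_freq, disc_updates, clip_weights):
    """One iteration of the loop from the question, with both branches merged.

    The strings below are placeholders for the real ops and tensors
    (d_opt, self.d_sum, self.d_fk_losses[0], and so on).
    """
    with_summaries = counter % save_freq == 0

    d_fetches = ["d_opt", "d_fk_loss", "d_rl_loss"]
    g_fetches = ["g_opt", "g_adv_loss", "g_l1_loss"]
    if with_summaries:
        # Summaries are fetched only every save_freq steps.
        d_fetches.insert(1, "d_sum")
        g_fetches.insert(1, "g_sum")

    d_losses = None
    for _ in range(disc_updates):
        d_out = sess.run(d_fetches)
        d_losses = d_out[-2:]  # (d_fk_loss, d_rl_loss)
        if clip_weights:
            sess.run("d_clip")

    # Generator update follows the discriminator updates.
    g_out = sess.run(g_fetches)
    g_losses = g_out[-2:]  # (g_adv_loss, g_l1_loss)
    return d_losses, g_losses


if __name__ == "__main__":
    sess = FakeSession()
    train_step(sess, counter=0, save_freq=10, disc_updates=1, clip_weights=False)
    for fetches in sess.calls:
        print(fetches)
```

Writing the step this way makes it easier to see that the summary fetches are the only difference between the two branches, which may help when narrowing down which sess.run call triggers the error.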