dansuh17 / segan-pytorch

SEGAN pytorch implementation https://arxiv.org/abs/1703.09452
GNU General Public License v3.0

how to predict #11

Open mechi33 opened 6 years ago

mechi33 commented 6 years ago

Hi, thank you for the nice implementation.

Regarding prediction, should I use the original TensorFlow version? (https://github.com/santi-pdp/segan/blob/master/main.py)

Your README only explains training, so I'd like to confirm how to run prediction with the trained model.

Regards,

dansuh17 commented 6 years ago

By prediction, do you mean predicting whether an audio clip is noisy or not (discriminator), or do you mean denoising (generator)? In either case, you can load the models saved at the end of every epoch using torch.load().

discriminator = torch.load('discriminator-5.pkl')
output = discriminator(audio)
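torch.load() returns whatever object was passed to torch.save(). Whether the loaded value can be called like discriminator(audio) therefore depends on how the checkpoint was written. The sketch below uses a small nn.Linear as a stand-in (an assumption; any nn.Module behaves the same way) and saves it both ways to in-memory buffers:

```python
import io
from collections import OrderedDict

import torch
import torch.nn as nn

# Stand-in module for illustration; a SEGAN generator or discriminator behaves the same way.
model = nn.Linear(4, 2)

# Option 1: save the whole module object (pickles a reference to its class too).
whole_buf = io.BytesIO()
torch.save(model, whole_buf)
whole_buf.seek(0)
# weights_only=False is needed to unpickle a full module on recent PyTorch versions.
loaded_module = torch.load(whole_buf, weights_only=False)

# Option 2: save only the parameters via .state_dict().
state_buf = io.BytesIO()
torch.save(model.state_dict(), state_buf)
state_buf.seek(0)
loaded_state = torch.load(state_buf)

print(type(loaded_module).__name__)  # Linear: a module, so it can be called
print(type(loaded_state).__name__)   # OrderedDict: only tensors, not callable
```

Only the first form can be called directly. A state_dict checkpoint has to be loaded into a freshly constructed model before it can be used.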
mechi33 commented 6 years ago

Thank you very much for your kind reply. I tried the code below and got an error:

gen=torch.load('generator-7.pkl')
nois_data='p232_001.wav'
output=gen(nois_data)
TypeError: 'collections.OrderedDict' object is not callable

Could you give some advice on this?

Thanks,
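The snippet above also passes the file name 'p232_001.wav' straight to the model. A PyTorch model expects a tensor, so the audio has to be decoded first. Below is a minimal sketch using only Python's standard `wave` module. It assumes 16-bit mono PCM input, and `wav_to_tensor` is a hypothetical helper, not part of the repo. If the repo's preprocessing also resampled or pre-emphasized the training audio, apply the same steps here.

```python
import io
import struct
import wave

import torch


def wav_to_tensor(source):
    """Read 16-bit mono PCM WAV data into a float tensor of shape (1, 1, num_samples).

    `source` may be a file path such as 'p232_001.wav' or a binary file-like object.
    """
    with wave.open(source, "rb") as wav:
        if wav.getsampwidth() != 2 or wav.getnchannels() != 1:
            raise ValueError("expected 16-bit mono PCM audio")
        num_frames = wav.getnframes()
        sample_rate = wav.getframerate()
        raw = wav.readframes(num_frames)
    # PCM WAV stores little-endian signed 16-bit integers.
    samples = struct.unpack("<%dh" % num_frames, raw)
    # Scale to [-1, 1) and add the batch and channel dimensions that Conv1d layers expect.
    audio = torch.tensor(samples, dtype=torch.float32) / 32768.0
    return audio.view(1, 1, -1), sample_rate


# Demo with a tiny in-memory WAV instead of a real file.
buf = io.BytesIO()
with wave.open(buf, "wb") as out:
    out.setnchannels(1)
    out.setsampwidth(2)
    out.setframerate(16000)
    out.writeframes(struct.pack("<4h", 0, 16384, -16384, 32767))
buf.seek(0)

x, sr = wav_to_tensor(buf)
print(x.shape, sr)  # torch.Size([1, 1, 4]) 16000
```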

dansuh17 commented 6 years ago

My bad. Since the training script saves each model's .state_dict() rather than the whole model, you need to create the model first and then load the state_dict into it.

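# model must be an instance of the same generator class used in training;
# torch.load() reads the saved state_dict and load_state_dict() copies it into the model
model.load_state_dict(torch.load('generator-7.pkl'))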

(This follows the pattern used in the PyTorch ImageNet example.)
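Putting the pieces together, denoising with a saved generator checkpoint could look like the sketch below. Several parts are assumptions to check against the repo. `StandInGenerator` is a hypothetical placeholder for the repo's generator class. The 16384-sample window (2**14, about one second at 16 kHz) follows the SEGAN paper. The real SEGAN generator also consumes a latent noise tensor z, so its forward call may need a second argument shaped as in the training script. The paper also applies a pre-emphasis filter, and if the training data was pre-emphasized, the same filter (and the inverse de-emphasis afterwards) should be applied at inference time.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.nn.functional as F

WINDOW = 16384  # assumed window length (2**14 samples); match the repo's preprocessing


class StandInGenerator(nn.Module):
    """Hypothetical placeholder for the repo's generator class.

    Swap in the real generator, built with the same constructor arguments as
    in training, so that the state_dict keys and tensor shapes match.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(1, 1, kernel_size=31, padding=15)

    def forward(self, x):
        return torch.tanh(self.conv(x))


def load_generator(path):
    """Build a fresh model, then copy the saved parameters into it."""
    model = StandInGenerator()
    state = torch.load(path, map_location="cpu")
    # nn.DataParallel prefixes every key with "module."; strip it if present.
    state = {k[len("module."):] if k.startswith("module.") else k: v
             for k, v in state.items()}
    model.load_state_dict(state)
    model.eval()  # switch dropout / batch-norm layers to inference behaviour
    return model


def enhance(generator, noisy):
    """Denoise a (1, 1, N) waveform by running the generator on fixed-size windows."""
    num_samples = noisy.size(-1)
    pad = (-num_samples) % WINDOW
    padded = F.pad(noisy, (0, pad))
    # (1, 1, N') -> (num_windows, 1, WINDOW): each window becomes one batch item.
    windows = padded.view(-1, 1, WINDOW)
    with torch.no_grad():
        enhanced = generator(windows)
    # Stitch the windows back together and drop the zero padding.
    return enhanced.reshape(1, 1, -1)[..., :num_samples]


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "generator-demo.pkl")  # hypothetical file name
    trained = StandInGenerator()
    torch.save(trained.state_dict(), path)  # state_dict-only checkpoint, as in training
    generator = load_generator(path)

noisy = torch.randn(1, 1, 40000) * 0.1
clean = enhance(generator, noisy)
print(clean.shape)  # torch.Size([1, 1, 40000])
```

Batching all windows in one forward pass keeps the code short. For long recordings, looping over windows instead bounds memory use. Non-overlapping windows can leave small artifacts at window boundaries, which overlapping windows with cross-fading would reduce.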