Hi,
`trGen.numFeats` stands for the total number of training frames. The built-in method `fit_generator` of Keras expects data from a Python generator, but the generator is expected to loop over its data indefinitely. Usual Python generators raise a `StopIteration` exception at the end of their data, but we can't do that here: the generator needs to keep fetching data. So there's no way to keep track of how many epochs have elapsed during training, except by counting `samples_per_epoch`. Just before the end of the first epoch, when `fit_generator` internally calls the `next()` method on `trGen`, it must fetch only the residual number of frames in the training data, even though they may not be enough to form a full batch. For example, suppose there are 100 training frames and the batch size is 32. After fetching data three times, `fit_generator` has processed 32*3 = 96 frames. When it calls `next()` again, `trGen` must fetch only the remaining 4 samples, since we set `samples_per_epoch` to 100. If it fetches more than 4 samples, `fit_generator` gives a warning. And on the next fetch, `trGen` should start from the beginning and return frames 1-32.
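To make that concrete, here is a minimal sketch of a generator with that wrap-around behaviour. It assumes NumPy arrays of frames and labels; only the attribute name `numFeats` is taken from the thread, and the rest is illustrative rather than the repo's actual `dataGenerator.py`.

```python
import numpy as np

class ToyFrameGenerator:
    """Illustrative 'infinite' generator for Keras 1.x fit_generator.

    Serves exactly numFeats frames per pass: the last batch of a pass is
    truncated to the residual frames (e.g. 4 frames when 96 of 100 are
    done), and the following call wraps around to frame 0 instead of
    raising StopIteration.
    """

    def __init__(self, feats, labels, batchSize=32):
        self.feats = np.asarray(feats)
        self.labels = np.asarray(labels)
        self.numFeats = len(self.feats)   # total number of training frames
        self.batchSize = batchSize
        self.pos = 0                      # position within the current pass

    def __iter__(self):
        return self

    def __next__(self):
        # Full batch if possible, otherwise only the residual frames.
        end = min(self.pos + self.batchSize, self.numFeats)
        x, y = self.feats[self.pos:end], self.labels[self.pos:end]
        # Wrap around so the next epoch starts again from the beginning.
        self.pos = end if end < self.numFeats else 0
        return x, y

    next = __next__   # Keras 1.x frequently ran under Python 2
```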
You must be training on about 46 hours of data (assuming a frame rate of 100 frames per second). A decent machine with a GPU shouldn't take more than 12 hours to train the network. If your machine has little RAM, try setting `self.maxSplitDataSize` inside `dataGenerator.py` to a lower value. It then loads fewer utterances into memory at once, but the number of IO operations per epoch increases. On the other hand, if it has plenty of memory but slower IO to its hard drive, try setting it to a higher value.
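In code, the trade-off looks roughly like this; the attribute name `self.maxSplitDataSize` is the one mentioned above, while the surrounding class and the default value are just placeholders.

```python
class DataGenerator:
    def __init__(self, maxSplitDataSize=100):
        # Number of utterances held in RAM at once per split:
        #   lower  -> smaller memory footprint, more disk reads per epoch
        #   higher -> fewer disk reads per epoch, larger memory footprint
        self.maxSplitDataSize = maxSplitDataSize
```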
Hope this helps.
Thank you so much for the swift response!
You are absolutely right, my training size is ~46 hours. Thanks to your post, I just realized that I accidentally installed tensorflow-cpu rather than tensorflow-gpu...
Thanks again for explaining `fit_generator`, `dataGenerator`, and `trGen.numFeats`!
Hi Mr Kumar,
Inside `steps_kt/train.py`, when you call `m.fit_generator` you set `samples_per_epoch` to `trGen.numFeats`, like the call sketched below. Is there a particular reason that `samples_per_epoch` is set this way? My concern is that my dataset seems to have 16,563,999 `trGen.numFeats`, and the network is too slow to train. My questions are: what does `trGen.numFeats` stand for, and why is it used as `samples_per_epoch`? Any input and advice will be greatly appreciated! Thanks in advance,
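A minimal sketch of that call, assuming the Keras 1.x `fit_generator` signature (the epoch count is a placeholder, not the value used in `steps_kt/train.py`):

```python
# trGen is the training-data generator; trGen.numFeats is the total
# number of training frames, so samples_per_epoch = trGen.numFeats
# means one epoch visits every frame exactly once.
m.fit_generator(trGen,
                samples_per_epoch=trGen.numFeats,
                nb_epoch=10)   # placeholder epoch count
```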