nl8590687 / ASRT_SpeechRecognition

A Deep-Learning-Based Chinese Speech Recognition System 基于深度学习的中文语音识别系统
https://asrt.ailemon.net
GNU General Public License v3.0
7.85k stars 1.9k forks

Why does data_generator() in readdata.py use while True to create an infinite loop? #257

Open DQ2020scut opened 3 years ago

DQ2020scut commented 3 years ago
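    # Note: `labels` (yielded below) is not defined in this excerpt; it is presumably created before the loop in the full function.
    while True:
        X = np.zeros((batch_size, audio_length, 200, 1), dtype=np.float64)  # np.float was removed in NumPy 1.24
        #y = np.zeros((batch_size, 64, self.SymbolNum), dtype=np.int16)
        y = np.zeros((batch_size, 64), dtype=np.int16)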

        #generator = ImageCaptcha(width=width, height=height)
        input_length = []
        label_length = []

        for i in range(batch_size):
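            ran_num = random.randint(0,self.DataNum - 1) # pick a random index
            data_input, data_labels = self.GetData(ran_num)  # fetch the sample at that index
            #data_input, data_labels = self.GetData((ran_num + i) % self.DataNum)  # take consecutive samples starting from the random index

            # About the line below (floor division by 8 plus the remainder mod 8): if it raises an error in practice, try adding 1 only when there is a remainder and 0 otherwise, or drop the addition entirely and keep only the floor division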
            input_length.append(data_input.shape[0] // 8 + data_input.shape[0] % 8)
            #print(data_input, data_labels)
            #print('data_input length:',len(data_input))

            X[i,0:len(data_input)] = data_input
            #print('data_labels length:',len(data_labels))
            #print(data_labels)
            y[i,0:len(data_labels)] = data_labels
            #print(i,y[i].shape)
            #y[i] = y[i].T
            #print(i,y[i].shape)
            label_length.append([len(data_labels)])

        label_length = np.matrix(label_length)
        input_length = np.array([input_length]).T
        #input_length = np.array(input_length)
        #print('input_length:\n',input_length)
        #X=X.reshape(batch_size, audio_length, 200, 1)
        #print(X)
        yield [X, y, input_length, label_length ], labels
    pass
nl8590687 commented 3 years ago

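It's mostly just one way of writing it: the generator draws data by random sampling with replacement, so it never runs out and the loop has no natural end. You could switch to an epoch-based approach instead; it's just that the epoch-based approach is the more common one nowadays.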
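
To make the contrast between the two styles concrete, here is a minimal sketch. It is not the ASRT code: the function names `sample_with_replacement` and `epoch_batches`, and the toy list standing in for the dataset, are hypothetical and chosen only for illustration. The first generator mirrors the `while True` pattern from the question, where every item is drawn independently and the caller decides how many batches to consume. The second makes one shuffled pass over the data and then stops.

```python
import random


def sample_with_replacement(data, batch_size, rng):
    """Infinite generator: each item is drawn independently, so repeats can occur."""
    while True:
        yield [data[rng.randrange(len(data))] for _ in range(batch_size)]


def epoch_batches(data, batch_size, rng):
    """Finite generator: one shuffled pass in which every item appears exactly once."""
    order = list(range(len(data)))
    rng.shuffle(order)
    for start in range(0, len(order), batch_size):
        yield [data[i] for i in order[start:start + batch_size]]


if __name__ == "__main__":
    rng = random.Random(0)
    data = list(range(10))  # toy stand-in for the dataset indices

    # Infinite style: the generator never ends, so the caller bounds the loop.
    gen = sample_with_replacement(data, 4, rng)
    for _ in range(3):
        print("random batch:", next(gen))

    # Epoch style: the generator ends by itself after one pass over the data.
    for batch in epoch_batches(data, 4, rng):
        print("epoch batch:", batch)
```

With the infinite style, the training loop has to be told how many batches make up one "epoch" (in Keras this is typically done through `steps_per_epoch`). With the epoch style, the last batch may be smaller than `batch_size` when the dataset size is not a multiple of it.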