Open DQ2020scut opened 3 years ago
while True: X = np.zeros((batch_size, audio_length, 200, 1), dtype = np.float) #y = np.zeros((batch_size, 64, self.SymbolNum), dtype=np.int16) y = np.zeros((batch_size, 64), dtype=np.int16) #generator = ImageCaptcha(width=width, height=height) input_length = [] label_length = [] for i in range(batch_size): ran_num = random.randint(0,self.DataNum - 1) # 获取一个随机数 data_input, data_labels = self.GetData(ran_num) # 通过随机数取一个数据 #data_input, data_labels = self.GetData((ran_num + i) % self.DataNum) # 从随机数开始连续向后取一定数量数据 # 关于下面这一行取整除以8 并加8的余数,在实际中如果遇到报错,可尝试只在有余数时+1,没有余数时+0,或者干脆都不加,只留整除 input_length.append(data_input.shape[0] // 8 + data_input.shape[0] % 8) #print(data_input, data_labels) #print('data_input长度:',len(data_input)) X[i,0:len(data_input)] = data_input #print('data_labels长度:',len(data_labels)) #print(data_labels) y[i,0:len(data_labels)] = data_labels #print(i,y[i].shape) #y[i] = y[i].T #print(i,y[i].shape) label_length.append([len(data_labels)]) label_length = np.matrix(label_length) input_length = np.array([input_length]).T #input_length = np.array(input_length) #print('input_length:\n',input_length) #X=X.reshape(batch_size, audio_length, 200, 1) #print(X) yield [X, y, input_length, label_length ], labels pass
主要是这样一种写法吧,一种有放回地随机抽样抽取数据的方式,也可以改用epoch方式啊,只不过现在用epoch的方式多