josephjaspers / blackcat_tensors

Matrix-Vector Library Designed for Neural Network Construction. CUDA (GPU) support, OpenMP (multithreaded CPU) support, partial BLAS support, an expression-template-based implementation (PTX code generation identical to hand-written kernels), and support for auto-differentiation.
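As a rough illustration of what an expression-template-based implementation means, here is a minimal, generic C++ sketch of the technique. The names `Expr`, `Vec` and `Add` are hypothetical and are not BlackCat Tensors' actual classes. The point is only that `a + b + c` builds a small expression tree that is evaluated element by element in a single loop, without allocating temporary vectors.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Generic expression-template sketch (illustrative names, not BlackCat's API).
// CRTP base: lets operator+ accept any expression type.
template <typename E>
struct Expr {
    const E& self() const { return static_cast<const E&>(*this); }
};

// A concrete vector that owns its data.
struct Vec : Expr<Vec> {
    std::vector<double> data;
    explicit Vec(std::size_t n, double v = 0.0) : data(n, v) {}

    double operator[](std::size_t i) const { return data[i]; }
    std::size_t size() const { return data.size(); }

    // Assigning an expression evaluates it element-wise in one pass.
    template <typename E>
    Vec& operator=(const Expr<E>& e) {
        const E& ex = e.self();
        assert(ex.size() == size());
        for (std::size_t i = 0; i < size(); ++i) data[i] = ex[i];
        return *this;
    }
};

// Node representing lhs + rhs; stores references, computes lazily.
template <typename L, typename R>
struct Add : Expr<Add<L, R>> {
    const L& lhs;
    const R& rhs;
    Add(const L& l, const R& r) : lhs(l), rhs(r) { assert(l.size() == r.size()); }
    double operator[](std::size_t i) const { return lhs[i] + rhs[i]; }
    std::size_t size() const { return lhs.size(); }
};

// Building the tree: no arithmetic happens here.
template <typename L, typename R>
Add<L, R> operator+(const Expr<L>& l, const Expr<R>& r) {
    return Add<L, R>(l.self(), r.self());
}
```

Usage: with `Vec a(3, 1.0), b(3, 2.0), c(3, 3.0), out(3);`, the statement `out = a + b + c;` fills `out` with 6.0 in one loop. Because the expression nodes hold references to temporaries, storing an expression (for example `auto e = a + b + c;`) and evaluating it later would dangle, so assign it to a `Vec` in the same statement.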

The forward_propagation and predict (batched forward) functions and the single_predict function produce different outputs for the same input #45

Closed: xinsuinizhuan closed this issue 4 years ago

josephjaspers commented 4 years ago

So according to this: https://stackoverflow.com/questions/39196945/in-keras-when-does-lstm-state-reset-in-the-call-to-model-predict

It seems that in Keras the same cell state is used for both prediction and training. Currently, predict (with the same batch_size) uses the same internal cell_state, like Keras. However, single_predict uses a different cell state, also like Keras.
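To make the effect concrete, here is a minimal C++ sketch using a hypothetical toy scalar recurrent cell (not BlackCat's LSTM class). It shows why the same input can produce different outputs when the internal state carries over between calls, and why resetting the state makes the result repeatable.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical toy recurrent cell (illustration only, not BlackCat's LSTM).
// It keeps one scalar hidden state h that persists between calls:
//     h_t = tanh(w * x_t + u * h_{t-1})
class ToyRecurrentCell {
public:
    ToyRecurrentCell(double w, double u) : w_(w), u_(u) {}

    // Run one input sequence. The final hidden state is NOT reset, so the
    // next call starts from where this one ended (stateful behaviour).
    double forward(const std::vector<double>& sequence) {
        assert(!sequence.empty());
        for (double x : sequence) {
            h_ = std::tanh(w_ * x + u_ * h_);
        }
        return h_;
    }

    // Zero the carried-over state so the next call starts fresh.
    void reset_state() { h_ = 0.0; }

private:
    double w_;
    double u_;
    double h_ = 0.0;
};
```

Calling `forward` twice on the same sequence returns two different values, because the second call starts from the hidden state left by the first; after `reset_state()` the first value is reproduced exactly. A real LSTM carries a cell state alongside the hidden state, but the same reasoning applies.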

Do you need single_predict to use the same cell-state as well?

xinsuinizhuan commented 4 years ago

I think it should use the same cell state. How does Keras predict handle a single image?

josephjaspers commented 4 years ago

So according to this: https://stackoverflow.com/questions/43882796/when-does-keras-reset-an-lstm-state

Keras (by default) resets its state after each sequence of batches. For us, that would mean zeroing the cell state after each call to 'update_weights' (I think).
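A minimal sketch of that reset policy, assuming a hypothetical toy model: the names `ToyModel`, `forward` and this `update_weights` signature are invented for illustration and are not BlackCat's API. The carried-over state is zeroed inside every weight update, so the first forward pass after an update depends only on the weights and the input.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical toy model (not BlackCat's API) with one scalar recurrent state.
struct ToyModel {
    double w = 0.8;  // input weight
    double u = 0.5;  // recurrent weight
    double h = 0.0;  // hidden state carried between forward calls

    double forward(const std::vector<double>& seq) {
        assert(!seq.empty());
        for (double x : seq) h = std::tanh(w * x + u * h);
        return h;
    }

    // Apply a precomputed gradient step, then zero the state, mirroring
    // "zero the cell state after each call to update_weights".
    void update_weights(double grad_w, double grad_u, double lr) {
        w -= lr * grad_w;
        u -= lr * grad_u;
        h = 0.0;
    }
};
```

With this policy, two models with equal weights give identical outputs for the same sequence right after an update, regardless of what was run before it.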

xinsuinizhuan commented 4 years ago

Sorry, the results now seem worse. Output of the predict and single_predict functions:

Neural Network architecture:
LSTM: inputs: 960 outputs: 1024
LSTM: inputs: 1024 outputs: 512
LSTM: inputs: 512 outputs: 216
FeedForward: inputs: 216 outputs: 192
Logistic: inputs: 192 outputs: 192
Output_Layer: inputs: 192 outputs: 192
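The log below reports a `predict MAPE loss`. For comparing the outputs of forward_propagation, predict and single_predict against the targets (or against each other), here is a minimal C++ sketch of mean absolute percentage error. It assumes the common definition, returned as a fraction; the issue does not show how its MAPE value is computed (for example whether it is scaled by 100 or how zero targets are handled), so treat this as an approximation.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Mean absolute percentage error as a fraction (0.05 means 5%).
// Assumption: textbook definition; zero targets are skipped because the
// percentage error is undefined for them.
double mape(const std::vector<double>& actual, const std::vector<double>& predicted) {
    assert(actual.size() == predicted.size());
    double total = 0.0;
    std::size_t counted = 0;
    for (std::size_t i = 0; i < actual.size(); ++i) {
        if (actual[i] == 0.0) continue;
        total += std::fabs((actual[i] - predicted[i]) / actual[i]);
        ++counted;
    }
    return counted == 0 ? 0.0 : total / static_cast<double>(counted);
}
```

Computing `mape(target, output)` separately for each of the three functions on the same input gives one number per function, which makes "different outputs for the same input" easy to quantify.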

training... imagesinput real data:------------------------------------ [[ , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , 0.235065, 0.235065, 0.390909, 0.390909, 0.441126, 0.395238, 0.441126, 0.386147, 0.450216, 0.386147, 0.409091, 0.422944, 0.418182, 0.436364, 0.390909, 0.336364, 0.458874, 0.463636, 0.381818, 0.436364, 0.381818, 0.390909, 0.436364, 0.409091, 0.381818, 0.441126, 0.431602, 0.352381, 0.352381, 0.427273, 0.431602, 0.404762, 0.327273, 0.540693, 0.472727, 0.409091, 0.472727, 0.409091, 0.468398, 0.413420, 0.463636, 0.409091, 0.468398, 0.409091, 0.463636, 0.404329, 0.459307, 0.404329, 0.459307, 0.404329, 0.422944, 0.467965, 0.409091, 0.477489, 0.409091, 0.418182, 0.463636, 0.413420, 0.463636, 0.413853, 0.404329, 0.404762, 0.472727] [0.463636, 0.290909, 0.422511, 0.254545, 0.763636, 0.436364, 0.368398, 0.231602, 0.695671, 0.349784, 0.550216, 0.413420, 0.350216, 0.536364, 0.400000, 0.463636, 0.363636, 0.363636, 0.518182, 0.458874, 0.413853, 0.463636, 0.409091, 0.418182, 0.467965, 0.422944, 0.400000, 0.467965, 0.418182, 0.409091, 0.409091, 0.432035, 0.463636, 0.400000, 0.472727, 0.400000, 0.400000, 0.458874, 0.427273, 0.422944, 0.463636, 0.349784, 0.209091, 0.558874, 0.558874, 0.418182, 0.400000, 0.468398, 0.109091, 0.727273, 0.413420, 0.404762, 0.458874, 0.404762, 0.400000, 0.395238, 0.450216, 0.454545, 0.436364, 0.395238, 0.395671, 0.436364, 0.458874, 0.318182, 0.461472, 0.461472, 0.386147, 0.386147, 0.386147, 0.590909, 0.359307, 0.440693, 0.440693, 0.413420, 0.427273, 0.241126, 0.618182, 0.381818, 0.390909, 0.518182, 0.240693, 0.627273, 0.304762, 0.386147, 0.541126, 0.427273, 0.415584, 0.415584, 0.359307, 0.531602, 0.354545, 0.438528, 0.438528, 0.418182, 0.449784, 0.390909] [0.413853, 0.336364, 0.531602, 0.409091, 0.345455, 0.461472, 0.461472, 0.463636, 0.345455, 0.513420, 0.400000, 0.459307, 0.286147, 0.447619, 0.447619, 0.531602, 0.400000, 0.450216, 0.418182, 0.340693, 0.495671, 0.463636, 0.404329, 0.454545, 0.400000, 0.354545, 
0.470563, 0.470563, 0.409091, 0.463636, 0.372727, 0.372727, 0.400000, 0.463636, 0.413420, 0.404762, 0.463636, 0.400000, 0.463636, 0.413420, 0.463636, 0.404762, 0.467965, 0.409091, 0.468398, 0.413420, 0.418182, 0.463636, 0.404762, 0.463636, 0.427273, 0.409091, 0.467965, 0.404762, 0.463636, 0.409091, 0.463636, 0.404329, 0.463636, 0.418182, 0.468398, 0.413420, 0.468398, 0.404329, 0.463636, 0.409091, 0.463636, 0.422944, 0.354545, 0.490909, 0.304329, 0.495671, 0.456710, 0.456710, 0.481818, 0.404762, 0.427273, 0.113420, 0.786580, 0.404329, 0.477489, 0.404329, 0.413853, 0.463636, 0.404329, 0.463636, 0.404762, 0.463636, 0.409091, 0.463636, 0.404329, 0.463636, 0.404762, 0.463636, 0.404329, 0.468398] [0.409091, 0.422511, 0.459307, 0.409091, 0.458874, 0.409091, 0.459307, 0.404329, 0.463636, 0.404762, 0.458874, 0.404762, 0.467965, 0.404762, 0.463636, 0.409091, 0.467965, 0.409091, 0.472727, 0.409091, 0.232035, 0.649784, 0.413853, 0.463636, 0.409091, 0.467965, 0.409091, 0.463636, 0.409091, 0.468398, 0.409091, 0.467965, 0.404762, 0.472727, 0.404329, 0.481818, 0.404762, 0.295238, 0.536364, 0.477489, 0.236364, 0.649784, 0.422944, 0.472727, 0.349784, 0.490909, 0.463636, 0.413853, 0.477056, 0.404762, 0.404329, 0.404762, 0.463636, 0.467965, 0.400000, 0.463636, 0.404762, 0.463636, 0.404329, 0.459307, 0.404329, 0.459307, 0.404329, 0.459307, 0.340693, 0.536364, 0.400000, 0.459307, 0.400000, 0.400000, 0.467965, 0.390909, 0.450216, 0.386147, 0.441126, 0.386147, 0.441126, 0.390909, 0.449784, 0.395671, 0.445455, 0.395238, 0.450216, 0.395238, 0.450216, 0.400000, 0.454545, 0.400000, 0.449784, 0.400000, 0.454545, 0.404762, 0.467965, 0.404762, 0.413420, 0.463636] [0.400000, 0.395671, 0.454545, 0.400000, 0.454545, 0.395238, 0.450216, 0.400000, 0.458874, 0.400000, 0.454545, 0.400000, 0.454545, 0.395671, 0.454545, 0.400000, 0.454545, 0.395238, 0.459307, 0.395238, 0.459307, 0.404329, 0.454545, 0.345455, 0.513853, 0.395238, 0.459307, 0.400000, 0.454545, 0.413420, 0.404762, 0.458874, 0.345455, 
0.467965, 0.467965, 0.432035, 0.231602, 0.590909, 0.395671, 0.440693, 0.400000, 0.272727, 0.418182, 0.559307, 0.386147, 0.272727, 0.527273, 0.381818, 0.438528, 0.438528, 0.404329, 0.386580, 0.436364, 0.377056, 0.432035, 0.386147, 0.432035, 0.377056, 0.436364, 0.381818, 0.441126, 0.381818, 0.436364, 0.381818, 0.436364, 0.377056, 0.436364, 0.377489, 0.431602, 0.381818, 0.432035, 0.377056, 0.432035, 0.381818, 0.427273, 0.381818, 0.427273, 0.377056, 0.432035, 0.377056, 0.432035, 0.377056, 0.432035, 0.395238, 0.322944, 0.286147, 0.254545, 0.732035, 0.309091, 0.563636, 0.386147, 0.190909, 0.432035, 0.500000, 0.500000, 0.404329] [0.436364, 0.404762, 0.440693, 0.413853, 0.390909, 0.200000, 0.451515, 0.451515, 0.451515, 0.451515, 0.451515, 0.436364, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.445455, 0.386147, 0.441126, 0.390909, 0.440693, 0.390909, 0.441126, 0.395238, 0.441126, 0.386147, 0.445455, 0.386580, 0.418182, 0.431602, 0.418182, 0.395671, 0.458874, 0.395671, 0.449784, 0.400000, 0.450216, 0.400000, 0.404329, 0.472727, 0.395671, 0.458874, 0.395671, 0.404329, 0.454545, 0.390909, 0.445455, 0.395671, 0.445455, 0.395238, 0.454545, 0.390909, 0.450216, 0.390909, 0.449784, 0.395671, 0.449784, 0.390909, 0.395671, 0.463636, 0.400000, 0.427273, 0.449784, 0.427273, 0.409091, 0.277489, 0.545455, 0.445455, 0.390909, 0.177056, 0.672727, 0.336364, 0.481818, 0.395671, 0.454545, 0.400000, 0.454545, 0.454545, 0.395238, 0.390909, 0.450216, 0.395238, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216] [0.390909, 0.449784, 0.390909, 0.341126, 0.518182, 0.390909, 0.449784, 0.390909, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216, 0.395238, 0.450216, 0.395238, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216, 0.390909, 0.449784, 0.400000, 0.450216, 0.381818, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.449784, 0.390909, 0.445455, 0.390909, 0.445455, 0.390909, 0.450216, 0.390909, 
0.449784, 0.390909, 0.445455, 0.395671, 0.440693, 0.386580, 0.440693, 0.381818, 0.445455, 0.386580, 0.440693, 0.386580, 0.440693, 0.390909, 0.441126, 0.381818, 0.440693, 0.381818, 0.441126, 0.386147, 0.395671, 0.413420, 0.450216, 0.386147, 0.441126, 0.386147, 0.445455, 0.390909, 0.441126, 0.409091, 0.386147, 0.441126, 0.386147, 0.445455, 0.390909, 0.441126, 0.381818, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.218182, 0.497835, 0.497835, 0.381818, 0.436364, 0.381818, 0.431602, 0.381818, 0.441126, 0.381818, 0.427273, 0.386147, 0.163636, ] [1.000000, 0.432035, 0.427273, 0.377056, 0.427273, 0.368398, 0.418182, 0.386147, 0.368398, 0.418182, 0.363636, 0.418182, 0.363636, 0.418182, 0.363636, 0.418182, 0.363636, 0.413420, 0.363636, 0.418182, 0.363636, 0.413853, 0.345455, 0.413420, 0.363636, 0.409091, 0.359307, 0.409091, 0.354545, 0.404329, 0.354545, 0.409091, 0.354545, 0.409091, 0.336364, 0.377489, 0.327273, 0.377056, 0.327273, 0.286580, 0.431602, 0.332035, 0.381818, 0.331602, 0.372727, 0.313853, 0.358874, 0.309091, 0.350216, 0.309091, 0.349784, 0.309091, 0.350216, 0.304329, 0.359307, 0.304329, 0.354545, 0.304762, 0.349784, 0.313853, 0.354545, 0.329437, 0.329437, 0.313853, 0.358874, 0.313853, 0.358874, 0.313853, 0.354545, 0.309091, 0.358874, 0.309091, 0.386580, 0.331602, 0.377489, 0.331602, 0.381818, 0.332035, 0.377056, 0.332035, 0.381818, 0.331602, 0.381818, 0.332035, 0.381818, 0.331602, 0.381818, 0.332035, 0.377056, 0.332035, 0.281818, 0.436364, 0.331602, 0.381818, 0.332035, 0.377056] [0.332035, 0.377056, 0.345455, 0.381818, 0.332035, 0.381818, 0.331602, 0.377489, 0.331602, 0.377489, 0.331602, 0.377489, 0.327273, 0.377056, 0.332035, 0.377056, 0.332035, 0.377056, 0.332035, 0.377056, 0.336364, 0.327273, 0.390909, 0.332035, 0.377056, 0.327273, 0.377489, 0.327273, 0.377056, 0.327273, 0.381818, 0.332035, 0.381818, 0.331602, 0.381818, 0.336364, 0.359307, 0.309091, 0.218182, 0.436364, 0.263636, 0.367965, 0.367965, 0.354545, 0.318182, 0.304329, 0.359307, 0.318182, 
0.304329, 0.350216, 0.304329, 0.354545, 0.318182, 0.304762, 0.367965, 0.304762, 0.345455, 0.309091, 0.313420, 0.354545, 0.309091, 0.350216, 0.309091, 0.354545, 0.309091, 0.349784, 0.313853, 0.354545, 0.267965, , , , , , , , , , , , 0.335931, 0.341126, 0.386147, 0.336364, 0.381818, 0.336364, 0.381818, 0.332035, 0.390909, 0.340693, 0.386580, 0.336364, 0.386147, 0.332035, 0.390909, 0.336364] [0.386147, 0.336364, 0.386580, 0.336364, 0.381818, 0.331602, 0.386580, 0.331602, 0.381818, 0.336364, 0.386580, 0.336364, 0.286147, 0.436364, 0.336364, 0.386580, 0.336364, 0.386147, 0.386580, 0.340693, 0.341126, 0.390909, 0.354545, 0.340693, 0.390909, 0.350216, 0.381818, 0.336364, 0.381818, 0.313420, 0.395671, 0.331602, 0.386580, 0.140693, 0.563636, 0.400000, 0.336364, , 0.609091, 0.395671, 0.345455, 0.395238, 0.350216, 0.400000, 0.336364, 0.386147, 0.336364, 0.336364, 0.381818, 0.332035, 0.381818, 0.349784, 0.336364, 0.390909, 0.329870, 0.385281, 0.329004, 0.341991, 0.367965, 0.380952, 0.346320, 0.324675, 0.354978, 0.307359, 0.359307, 0.320346, 0.311688, 0.354978, 0.311688, 0.341991, 0.303030, 0.350649, 0.311688, 0.346320, 0.316017, 0.354978, 0.311688, 0.354978, 0.303030, 0.354978, 0.311688, 0.367965, 0.346320, 0.316017, 0.103896, 0.584416, 0.467532, 0.329004, 0.389610, 0.238095, 0.378788, 0.378788, 0.298701, 0.419913, 0.419913, 0.333333]] imagesoutput real data:------------------------------------ [[0.380952, 0.337662, 0.385281, 0.341991, 0.385281, 0.354978, 0.341991, 0.393939, 0.346320, 0.346320, 0.389610, 0.341991, 0.393939, 0.346320, 0.337662, 0.385281, 0.341991, 0.398268, 0.151515, 0.580087, 0.337662, 0.337662, 0.389610, 0.359307, 0.385281, 0.337662, 0.385281, 0.337662, 0.380952, 0.346320, 0.341991, 0.389610, 0.333333, 0.393939, 0.333333, 0.380952, 0.341991, 0.337662, 0.385281, 0.337662, 0.385281, 0.333333, 0.385281, 0.333333, 0.337662, 0.385281, 0.324675, 0.376623, 0.333333, 0.376623, 0.329004, 0.367965, 0.320346, 0.372294, 0.329004, 0.367965, 0.324675, 0.372294, 0.320346, 
0.372294, 0.324675, 0.367965, 0.324675, 0.367965, 0.298701, 0.307359, 0.298701, 0.341991, 0.307359, 0.350649, 0.298701, 0.350649, 0.298701, 0.354978, 0.298701, 0.346320, 0.303030, 0.350649, 0.303030, 0.346320, 0.307359, 0.346320, 0.307359, 0.346320, 0.307359, 0.380952, 0.333333, 0.376623, 0.333333, 0.380952, 0.333333, 0.376623, 0.341991, 0.376623, 0.333333, 0.380952] [0.333333, 0.380952, 0.337662, 0.337662, 0.389610, 0.333333, 0.142857, 0.601732, 0.341991, 0.350649, 0.389610, 0.341991, 0.385281, 0.346320, 0.337662, 0.385281, 0.333333, 0.393939, 0.333333, 0.393939, 0.341991, 0.350649, 0.346320, 0.406926, 0.346320, 0.333333, 0.147186, 0.580087, 0.389610, 0.333333, 0.350649, 0.380952, 0.341991, 0.393939, 0.341991, 0.389610, 0.346320, 0.389610, 0.346320, 0.385281, 0.341991, 0.393939, 0.341991, 0.389610, 0.346320, 0.380952, 0.337662, 0.385281, 0.333333, 0.380952, 0.341991, 0.333333, 0.376623, 0.376623, 0.333333, 0.376623, 0.329004, 0.341991, 0.380952, 0.385281, 0.337662, 0.337662, 0.385281, 0.337662, 0.385281, 0.337662, 0.385281, 0.341991, 0.389610, 0.341991, 0.385281, 0.333333, 0.385281, 0.337662, 0.380952, 0.337662, 0.385281, 0.341991, 0.380952, 0.337662, 0.376623, 0.303030, 0.346320, 0.307359, 0.350649, 0.307359, 0.346320, 0.303030, 0.350649, 0.329004, 0.372294, 0.303030, 0.341991, 0.303030, 0.337662, 0.303030]]

forward_propagation inputdata------------------------------------ [[ , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , 0.235065, 0.235065, 0.390909, 0.390909, 0.441126, 0.395238, 0.441126, 0.386147, 0.450216, 0.386147, 0.409091, 0.422944, 0.418182, 0.436364, 0.390909, 0.336364, 0.458874, 0.463636, 0.381818, 0.436364, 0.381818, 0.390909, 0.436364, 0.409091, 0.381818, 0.441126, 0.431602, 0.352381, 0.352381, 0.427273, 0.431602, 0.404762, 0.327273, 0.540693, 0.472727, 0.409091, 0.472727, 0.409091, 0.468398, 0.413420, 0.463636, 0.409091, 0.468398, 0.409091, 0.463636, 0.404329, 0.459307, 0.404329, 0.459307, 0.404329, 0.422944, 0.467965, 0.409091, 0.477489, 0.409091, 0.418182, 0.463636, 0.413420, 0.463636, 0.413853, 0.404329, 0.404762, 0.472727] [0.463636, 0.290909, 0.422511, 0.254545, 0.763636, 0.436364, 0.368398, 0.231602, 0.695671, 0.349784, 0.550216, 0.413420, 0.350216, 0.536364, 0.400000, 0.463636, 0.363636, 0.363636, 0.518182, 0.458874, 0.413853, 0.463636, 0.409091, 0.418182, 0.467965, 0.422944, 0.400000, 0.467965, 0.418182, 0.409091, 0.409091, 0.432035, 0.463636, 0.400000, 0.472727, 0.400000, 0.400000, 0.458874, 0.427273, 0.422944, 0.463636, 0.349784, 0.209091, 0.558874, 0.558874, 0.418182, 0.400000, 0.468398, 0.109091, 0.727273, 0.413420, 0.404762, 0.458874, 0.404762, 0.400000, 0.395238, 0.450216, 0.454545, 0.436364, 0.395238, 0.395671, 0.436364, 0.458874, 0.318182, 0.461472, 0.461472, 0.386147, 0.386147, 0.386147, 0.590909, 0.359307, 0.440693, 0.440693, 0.413420, 0.427273, 0.241126, 0.618182, 0.381818, 0.390909, 0.518182, 0.240693, 0.627273, 0.304762, 0.386147, 0.541126, 0.427273, 0.415584, 0.415584, 0.359307, 0.531602, 0.354545, 0.438528, 0.438528, 0.418182, 0.449784, 0.390909] [0.413853, 0.336364, 0.531602, 0.409091, 0.345455, 0.461472, 0.461472, 0.463636, 0.345455, 0.513420, 0.400000, 0.459307, 0.286147, 0.447619, 0.447619, 0.531602, 0.400000, 0.450216, 0.418182, 0.340693, 0.495671, 0.463636, 0.404329, 0.454545, 0.400000, 0.354545, 0.470563, 
0.470563, 0.409091, 0.463636, 0.372727, 0.372727, 0.400000, 0.463636, 0.413420, 0.404762, 0.463636, 0.400000, 0.463636, 0.413420, 0.463636, 0.404762, 0.467965, 0.409091, 0.468398, 0.413420, 0.418182, 0.463636, 0.404762, 0.463636, 0.427273, 0.409091, 0.467965, 0.404762, 0.463636, 0.409091, 0.463636, 0.404329, 0.463636, 0.418182, 0.468398, 0.413420, 0.468398, 0.404329, 0.463636, 0.409091, 0.463636, 0.422944, 0.354545, 0.490909, 0.304329, 0.495671, 0.456710, 0.456710, 0.481818, 0.404762, 0.427273, 0.113420, 0.786580, 0.404329, 0.477489, 0.404329, 0.413853, 0.463636, 0.404329, 0.463636, 0.404762, 0.463636, 0.409091, 0.463636, 0.404329, 0.463636, 0.404762, 0.463636, 0.404329, 0.468398] [0.409091, 0.422511, 0.459307, 0.409091, 0.458874, 0.409091, 0.459307, 0.404329, 0.463636, 0.404762, 0.458874, 0.404762, 0.467965, 0.404762, 0.463636, 0.409091, 0.467965, 0.409091, 0.472727, 0.409091, 0.232035, 0.649784, 0.413853, 0.463636, 0.409091, 0.467965, 0.409091, 0.463636, 0.409091, 0.468398, 0.409091, 0.467965, 0.404762, 0.472727, 0.404329, 0.481818, 0.404762, 0.295238, 0.536364, 0.477489, 0.236364, 0.649784, 0.422944, 0.472727, 0.349784, 0.490909, 0.463636, 0.413853, 0.477056, 0.404762, 0.404329, 0.404762, 0.463636, 0.467965, 0.400000, 0.463636, 0.404762, 0.463636, 0.404329, 0.459307, 0.404329, 0.459307, 0.404329, 0.459307, 0.340693, 0.536364, 0.400000, 0.459307, 0.400000, 0.400000, 0.467965, 0.390909, 0.450216, 0.386147, 0.441126, 0.386147, 0.441126, 0.390909, 0.449784, 0.395671, 0.445455, 0.395238, 0.450216, 0.395238, 0.450216, 0.400000, 0.454545, 0.400000, 0.449784, 0.400000, 0.454545, 0.404762, 0.467965, 0.404762, 0.413420, 0.463636] [0.400000, 0.395671, 0.454545, 0.400000, 0.454545, 0.395238, 0.450216, 0.400000, 0.458874, 0.400000, 0.454545, 0.400000, 0.454545, 0.395671, 0.454545, 0.400000, 0.454545, 0.395238, 0.459307, 0.395238, 0.459307, 0.404329, 0.454545, 0.345455, 0.513853, 0.395238, 0.459307, 0.400000, 0.454545, 0.413420, 0.404762, 0.458874, 0.345455, 0.467965, 
0.467965, 0.432035, 0.231602, 0.590909, 0.395671, 0.440693, 0.400000, 0.272727, 0.418182, 0.559307, 0.386147, 0.272727, 0.527273, 0.381818, 0.438528, 0.438528, 0.404329, 0.386580, 0.436364, 0.377056, 0.432035, 0.386147, 0.432035, 0.377056, 0.436364, 0.381818, 0.441126, 0.381818, 0.436364, 0.381818, 0.436364, 0.377056, 0.436364, 0.377489, 0.431602, 0.381818, 0.432035, 0.377056, 0.432035, 0.381818, 0.427273, 0.381818, 0.427273, 0.377056, 0.432035, 0.377056, 0.432035, 0.377056, 0.432035, 0.395238, 0.322944, 0.286147, 0.254545, 0.732035, 0.309091, 0.563636, 0.386147, 0.190909, 0.432035, 0.500000, 0.500000, 0.404329] [0.436364, 0.404762, 0.440693, 0.413853, 0.390909, 0.200000, 0.451515, 0.451515, 0.451515, 0.451515, 0.451515, 0.436364, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.445455, 0.386147, 0.441126, 0.390909, 0.440693, 0.390909, 0.441126, 0.395238, 0.441126, 0.386147, 0.445455, 0.386580, 0.418182, 0.431602, 0.418182, 0.395671, 0.458874, 0.395671, 0.449784, 0.400000, 0.450216, 0.400000, 0.404329, 0.472727, 0.395671, 0.458874, 0.395671, 0.404329, 0.454545, 0.390909, 0.445455, 0.395671, 0.445455, 0.395238, 0.454545, 0.390909, 0.450216, 0.390909, 0.449784, 0.395671, 0.449784, 0.390909, 0.395671, 0.463636, 0.400000, 0.427273, 0.449784, 0.427273, 0.409091, 0.277489, 0.545455, 0.445455, 0.390909, 0.177056, 0.672727, 0.336364, 0.481818, 0.395671, 0.454545, 0.400000, 0.454545, 0.454545, 0.395238, 0.390909, 0.450216, 0.395238, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216] [0.390909, 0.449784, 0.390909, 0.341126, 0.518182, 0.390909, 0.449784, 0.390909, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216, 0.395238, 0.450216, 0.395238, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216, 0.390909, 0.449784, 0.400000, 0.450216, 0.381818, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.449784, 0.390909, 0.445455, 0.390909, 0.445455, 0.390909, 0.450216, 0.390909, 0.449784, 
0.390909, 0.445455, 0.395671, 0.440693, 0.386580, 0.440693, 0.381818, 0.445455, 0.386580, 0.440693, 0.386580, 0.440693, 0.390909, 0.441126, 0.381818, 0.440693, 0.381818, 0.441126, 0.386147, 0.395671, 0.413420, 0.450216, 0.386147, 0.441126, 0.386147, 0.445455, 0.390909, 0.441126, 0.409091, 0.386147, 0.441126, 0.386147, 0.445455, 0.390909, 0.441126, 0.381818, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.218182, 0.497835, 0.497835, 0.381818, 0.436364, 0.381818, 0.431602, 0.381818, 0.441126, 0.381818, 0.427273, 0.386147, 0.163636, ] [1.000000, 0.432035, 0.427273, 0.377056, 0.427273, 0.368398, 0.418182, 0.386147, 0.368398, 0.418182, 0.363636, 0.418182, 0.363636, 0.418182, 0.363636, 0.418182, 0.363636, 0.413420, 0.363636, 0.418182, 0.363636, 0.413853, 0.345455, 0.413420, 0.363636, 0.409091, 0.359307, 0.409091, 0.354545, 0.404329, 0.354545, 0.409091, 0.354545, 0.409091, 0.336364, 0.377489, 0.327273, 0.377056, 0.327273, 0.286580, 0.431602, 0.332035, 0.381818, 0.331602, 0.372727, 0.313853, 0.358874, 0.309091, 0.350216, 0.309091, 0.349784, 0.309091, 0.350216, 0.304329, 0.359307, 0.304329, 0.354545, 0.304762, 0.349784, 0.313853, 0.354545, 0.329437, 0.329437, 0.313853, 0.358874, 0.313853, 0.358874, 0.313853, 0.354545, 0.309091, 0.358874, 0.309091, 0.386580, 0.331602, 0.377489, 0.331602, 0.381818, 0.332035, 0.377056, 0.332035, 0.381818, 0.331602, 0.381818, 0.332035, 0.381818, 0.331602, 0.381818, 0.332035, 0.377056, 0.332035, 0.281818, 0.436364, 0.331602, 0.381818, 0.332035, 0.377056] [0.332035, 0.377056, 0.345455, 0.381818, 0.332035, 0.381818, 0.331602, 0.377489, 0.331602, 0.377489, 0.331602, 0.377489, 0.327273, 0.377056, 0.332035, 0.377056, 0.332035, 0.377056, 0.332035, 0.377056, 0.336364, 0.327273, 0.390909, 0.332035, 0.377056, 0.327273, 0.377489, 0.327273, 0.377056, 0.327273, 0.381818, 0.332035, 0.381818, 0.331602, 0.381818, 0.336364, 0.359307, 0.309091, 0.218182, 0.436364, 0.263636, 0.367965, 0.367965, 0.354545, 0.318182, 0.304329, 0.359307, 0.318182, 0.304329, 
0.350216, 0.304329, 0.354545, 0.318182, 0.304762, 0.367965, 0.304762, 0.345455, 0.309091, 0.313420, 0.354545, 0.309091, 0.350216, 0.309091, 0.354545, 0.309091, 0.349784, 0.313853, 0.354545, 0.267965, , , , , , , , , , , , 0.335931, 0.341126, 0.386147, 0.336364, 0.381818, 0.336364, 0.381818, 0.332035, 0.390909, 0.340693, 0.386580, 0.336364, 0.386147, 0.332035, 0.390909, 0.336364] [0.386147, 0.336364, 0.386580, 0.336364, 0.381818, 0.331602, 0.386580, 0.331602, 0.381818, 0.336364, 0.386580, 0.336364, 0.286147, 0.436364, 0.336364, 0.386580, 0.336364, 0.386147, 0.386580, 0.340693, 0.341126, 0.390909, 0.354545, 0.340693, 0.390909, 0.350216, 0.381818, 0.336364, 0.381818, 0.313420, 0.395671, 0.331602, 0.386580, 0.140693, 0.563636, 0.400000, 0.336364, , 0.609091, 0.395671, 0.345455, 0.395238, 0.350216, 0.400000, 0.336364, 0.386147, 0.336364, 0.336364, 0.381818, 0.332035, 0.381818, 0.349784, 0.336364, 0.390909, 0.329870, 0.385281, 0.329004, 0.341991, 0.367965, 0.380952, 0.346320, 0.324675, 0.354978, 0.307359, 0.359307, 0.320346, 0.311688, 0.354978, 0.311688, 0.341991, 0.303030, 0.350649, 0.311688, 0.346320, 0.316017, 0.354978, 0.311688, 0.354978, 0.303030, 0.354978, 0.311688, 0.367965, 0.346320, 0.316017, 0.103896, 0.584416, 0.467532, 0.329004, 0.389610, 0.238095, 0.378788, 0.378788, 0.298701, 0.419913, 0.419913, 0.333333]] forward_propagation output predict data------------------------------------ [0.310533, 0.383703, 0.365227, 0.359723, 0.369998, 0.347733, 0.329489, 0.298469, 0.457006, 0.332383, 0.380970, 0.364110, 0.364936, 0.344942, 0.339252, 0.365484, 0.342304, 0.412141, 0.311666, 0.462750, 0.335688, 0.339238, 0.360699, 0.360489, 0.366455, 0.384661, 0.315686, 0.349905, 0.361444, 0.307811, 0.338125, 0.367147, 0.306659, 0.370067, 0.395402, 0.359471, 0.372835, 0.254907, 0.390919, 0.294525, 0.391681, 0.317846, 0.395277, 0.345315, 0.325443, 0.403894, 0.395480, 0.421372, 0.364338, 0.321313, 0.344960, 0.447952, 0.306943, 0.314869, 0.365779, 0.376821, 0.290613, 0.294666, 
0.285751, 0.352383, 0.268660, 0.272055, 0.425920, 0.376848, 0.373882, 0.392094, 0.380374, 0.266039, 0.355668, 0.259666, 0.295560, 0.373857, 0.353257, 0.401727, 0.281745, 0.324097, 0.289450, 0.303454, 0.273620, 0.415892, 0.304918, 0.303497, 0.324329, 0.357835, 0.285150, 0.284309, 0.329921, 0.293902, 0.356708, 0.354956, 0.383000, 0.324327, 0.337800, 0.350772, 0.334236, 0.335106, 0.307106, 0.272402, 0.321818, 0.323625, 0.394443, 0.397656, 0.310069, 0.456427, 0.433450, 0.376627, 0.333204, 0.367536, 0.384775, 0.387094, 0.329576, 0.330101, 0.311769, 0.382324, 0.348715, 0.338723, 0.329414, 0.337339, 0.333832, 0.421364, 0.399090, 0.369570, 0.247406, 0.498754, 0.436079, 0.353101, 0.353068, 0.383053, 0.370463, 0.404493, 0.437789, 0.431844, 0.350418, 0.405932, 0.372235, 0.321584, 0.374972, 0.444291, 0.368342, 0.381174, 0.307321, 0.278165, 0.335184, 0.308720, 0.259972, 0.377931, 0.415331, 0.366818, 0.348073, 0.365168, 0.368561, 0.377450, 0.388073, 0.349286, 0.330372, 0.393218, 0.353819, 0.321392, 0.348001, 0.354104, 0.274504, 0.332613, 0.319378, 0.366077, 0.337131, 0.002016, 0.308405, 0.273303, 0.328251, 0.301632, 0.376895, 0.299215, 0.405439, 0.324014, 0.339204, 0.385824, 0.392870, 0.301711, 0.332651, 0.331799, 0.373001, 0.304462, 0.340968, 0.286348, 0.399842, 0.304851, 0.388724, 0.248480, 0.334762, 0.378910, 0.407318, 0.372582] predict MAPE loss: 0.028148 predict inputdata------------------------------------ [[ , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , 0.235065, 0.235065, 0.390909, 0.390909, 0.441126, 0.395238, 0.441126, 0.386147, 0.450216, 0.386147, 0.409091, 0.422944, 0.418182, 0.436364, 0.390909, 0.336364, 0.458874, 0.463636, 0.381818, 0.436364, 0.381818, 0.390909, 0.436364, 0.409091, 0.381818, 0.441126, 0.431602, 0.352381, 0.352381, 0.427273, 0.431602, 0.404762, 0.327273, 0.540693, 0.472727, 0.409091, 0.472727, 0.409091, 0.468398, 0.413420, 0.463636, 0.409091, 0.468398, 0.409091, 0.463636, 0.404329, 0.459307, 0.404329, 0.459307, 0.404329, 
0.422944, 0.467965, 0.409091, 0.477489, 0.409091, 0.418182, 0.463636, 0.413420, 0.463636, 0.413853, 0.404329, 0.404762, 0.472727] [0.463636, 0.290909, 0.422511, 0.254545, 0.763636, 0.436364, 0.368398, 0.231602, 0.695671, 0.349784, 0.550216, 0.413420, 0.350216, 0.536364, 0.400000, 0.463636, 0.363636, 0.363636, 0.518182, 0.458874, 0.413853, 0.463636, 0.409091, 0.418182, 0.467965, 0.422944, 0.400000, 0.467965, 0.418182, 0.409091, 0.409091, 0.432035, 0.463636, 0.400000, 0.472727, 0.400000, 0.400000, 0.458874, 0.427273, 0.422944, 0.463636, 0.349784, 0.209091, 0.558874, 0.558874, 0.418182, 0.400000, 0.468398, 0.109091, 0.727273, 0.413420, 0.404762, 0.458874, 0.404762, 0.400000, 0.395238, 0.450216, 0.454545, 0.436364, 0.395238, 0.395671, 0.436364, 0.458874, 0.318182, 0.461472, 0.461472, 0.386147, 0.386147, 0.386147, 0.590909, 0.359307, 0.440693, 0.440693, 0.413420, 0.427273, 0.241126, 0.618182, 0.381818, 0.390909, 0.518182, 0.240693, 0.627273, 0.304762, 0.386147, 0.541126, 0.427273, 0.415584, 0.415584, 0.359307, 0.531602, 0.354545, 0.438528, 0.438528, 0.418182, 0.449784, 0.390909] [0.413853, 0.336364, 0.531602, 0.409091, 0.345455, 0.461472, 0.461472, 0.463636, 0.345455, 0.513420, 0.400000, 0.459307, 0.286147, 0.447619, 0.447619, 0.531602, 0.400000, 0.450216, 0.418182, 0.340693, 0.495671, 0.463636, 0.404329, 0.454545, 0.400000, 0.354545, 0.470563, 0.470563, 0.409091, 0.463636, 0.372727, 0.372727, 0.400000, 0.463636, 0.413420, 0.404762, 0.463636, 0.400000, 0.463636, 0.413420, 0.463636, 0.404762, 0.467965, 0.409091, 0.468398, 0.413420, 0.418182, 0.463636, 0.404762, 0.463636, 0.427273, 0.409091, 0.467965, 0.404762, 0.463636, 0.409091, 0.463636, 0.404329, 0.463636, 0.418182, 0.468398, 0.413420, 0.468398, 0.404329, 0.463636, 0.409091, 0.463636, 0.422944, 0.354545, 0.490909, 0.304329, 0.495671, 0.456710, 0.456710, 0.481818, 0.404762, 0.427273, 0.113420, 0.786580, 0.404329, 0.477489, 0.404329, 0.413853, 0.463636, 0.404329, 0.463636, 0.404762, 0.463636, 0.409091, 0.463636, 
0.404329, 0.463636, 0.404762, 0.463636, 0.404329, 0.468398] [0.409091, 0.422511, 0.459307, 0.409091, 0.458874, 0.409091, 0.459307, 0.404329, 0.463636, 0.404762, 0.458874, 0.404762, 0.467965, 0.404762, 0.463636, 0.409091, 0.467965, 0.409091, 0.472727, 0.409091, 0.232035, 0.649784, 0.413853, 0.463636, 0.409091, 0.467965, 0.409091, 0.463636, 0.409091, 0.468398, 0.409091, 0.467965, 0.404762, 0.472727, 0.404329, 0.481818, 0.404762, 0.295238, 0.536364, 0.477489, 0.236364, 0.649784, 0.422944, 0.472727, 0.349784, 0.490909, 0.463636, 0.413853, 0.477056, 0.404762, 0.404329, 0.404762, 0.463636, 0.467965, 0.400000, 0.463636, 0.404762, 0.463636, 0.404329, 0.459307, 0.404329, 0.459307, 0.404329, 0.459307, 0.340693, 0.536364, 0.400000, 0.459307, 0.400000, 0.400000, 0.467965, 0.390909, 0.450216, 0.386147, 0.441126, 0.386147, 0.441126, 0.390909, 0.449784, 0.395671, 0.445455, 0.395238, 0.450216, 0.395238, 0.450216, 0.400000, 0.454545, 0.400000, 0.449784, 0.400000, 0.454545, 0.404762, 0.467965, 0.404762, 0.413420, 0.463636] [0.400000, 0.395671, 0.454545, 0.400000, 0.454545, 0.395238, 0.450216, 0.400000, 0.458874, 0.400000, 0.454545, 0.400000, 0.454545, 0.395671, 0.454545, 0.400000, 0.454545, 0.395238, 0.459307, 0.395238, 0.459307, 0.404329, 0.454545, 0.345455, 0.513853, 0.395238, 0.459307, 0.400000, 0.454545, 0.413420, 0.404762, 0.458874, 0.345455, 0.467965, 0.467965, 0.432035, 0.231602, 0.590909, 0.395671, 0.440693, 0.400000, 0.272727, 0.418182, 0.559307, 0.386147, 0.272727, 0.527273, 0.381818, 0.438528, 0.438528, 0.404329, 0.386580, 0.436364, 0.377056, 0.432035, 0.386147, 0.432035, 0.377056, 0.436364, 0.381818, 0.441126, 0.381818, 0.436364, 0.381818, 0.436364, 0.377056, 0.436364, 0.377489, 0.431602, 0.381818, 0.432035, 0.377056, 0.432035, 0.381818, 0.427273, 0.381818, 0.427273, 0.377056, 0.432035, 0.377056, 0.432035, 0.377056, 0.432035, 0.395238, 0.322944, 0.286147, 0.254545, 0.732035, 0.309091, 0.563636, 0.386147, 0.190909, 0.432035, 0.500000, 0.500000, 0.404329] [0.436364, 
0.404762, 0.440693, 0.413853, 0.390909, 0.200000, 0.451515, 0.451515, 0.451515, 0.451515, 0.451515, 0.436364, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.445455, 0.386147, 0.441126, 0.390909, 0.440693, 0.390909, 0.441126, 0.395238, 0.441126, 0.386147, 0.445455, 0.386580, 0.418182, 0.431602, 0.418182, 0.395671, 0.458874, 0.395671, 0.449784, 0.400000, 0.450216, 0.400000, 0.404329, 0.472727, 0.395671, 0.458874, 0.395671, 0.404329, 0.454545, 0.390909, 0.445455, 0.395671, 0.445455, 0.395238, 0.454545, 0.390909, 0.450216, 0.390909, 0.449784, 0.395671, 0.449784, 0.390909, 0.395671, 0.463636, 0.400000, 0.427273, 0.449784, 0.427273, 0.409091, 0.277489, 0.545455, 0.445455, 0.390909, 0.177056, 0.672727, 0.336364, 0.481818, 0.395671, 0.454545, 0.400000, 0.454545, 0.454545, 0.395238, 0.390909, 0.450216, 0.395238, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216] [0.390909, 0.449784, 0.390909, 0.341126, 0.518182, 0.390909, 0.449784, 0.390909, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216, 0.395238, 0.450216, 0.395238, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216, 0.390909, 0.449784, 0.400000, 0.450216, 0.381818, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.449784, 0.390909, 0.445455, 0.390909, 0.445455, 0.390909, 0.450216, 0.390909, 0.449784, 0.390909, 0.445455, 0.395671, 0.440693, 0.386580, 0.440693, 0.381818, 0.445455, 0.386580, 0.440693, 0.386580, 0.440693, 0.390909, 0.441126, 0.381818, 0.440693, 0.381818, 0.441126, 0.386147, 0.395671, 0.413420, 0.450216, 0.386147, 0.441126, 0.386147, 0.445455, 0.390909, 0.441126, 0.409091, 0.386147, 0.441126, 0.386147, 0.445455, 0.390909, 0.441126, 0.381818, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.218182, 0.497835, 0.497835, 0.381818, 0.436364, 0.381818, 0.431602, 0.381818, 0.441126, 0.381818, 0.427273, 0.386147, 0.163636, ] [1.000000, 0.432035, 0.427273, 0.377056, 0.427273, 0.368398, 0.418182, 0.386147, 0.368398, 
0.418182, 0.363636, 0.418182, 0.363636, 0.418182, 0.363636, 0.418182, 0.363636, 0.413420, 0.363636, 0.418182, 0.363636, 0.413853, 0.345455, 0.413420, 0.363636, 0.409091, 0.359307, 0.409091, 0.354545, 0.404329, 0.354545, 0.409091, 0.354545, 0.409091, 0.336364, 0.377489, 0.327273, 0.377056, 0.327273, 0.286580, 0.431602, 0.332035, 0.381818, 0.331602, 0.372727, 0.313853, 0.358874, 0.309091, 0.350216, 0.309091, 0.349784, 0.309091, 0.350216, 0.304329, 0.359307, 0.304329, 0.354545, 0.304762, 0.349784, 0.313853, 0.354545, 0.329437, 0.329437, 0.313853, 0.358874, 0.313853, 0.358874, 0.313853, 0.354545, 0.309091, 0.358874, 0.309091, 0.386580, 0.331602, 0.377489, 0.331602, 0.381818, 0.332035, 0.377056, 0.332035, 0.381818, 0.331602, 0.381818, 0.332035, 0.381818, 0.331602, 0.381818, 0.332035, 0.377056, 0.332035, 0.281818, 0.436364, 0.331602, 0.381818, 0.332035, 0.377056] [0.332035, 0.377056, 0.345455, 0.381818, 0.332035, 0.381818, 0.331602, 0.377489, 0.331602, 0.377489, 0.331602, 0.377489, 0.327273, 0.377056, 0.332035, 0.377056, 0.332035, 0.377056, 0.332035, 0.377056, 0.336364, 0.327273, 0.390909, 0.332035, 0.377056, 0.327273, 0.377489, 0.327273, 0.377056, 0.327273, 0.381818, 0.332035, 0.381818, 0.331602, 0.381818, 0.336364, 0.359307, 0.309091, 0.218182, 0.436364, 0.263636, 0.367965, 0.367965, 0.354545, 0.318182, 0.304329, 0.359307, 0.318182, 0.304329, 0.350216, 0.304329, 0.354545, 0.318182, 0.304762, 0.367965, 0.304762, 0.345455, 0.309091, 0.313420, 0.354545, 0.309091, 0.350216, 0.309091, 0.354545, 0.309091, 0.349784, 0.313853, 0.354545, 0.267965, , , , , , , , , , , , 0.335931, 0.341126, 0.386147, 0.336364, 0.381818, 0.336364, 0.381818, 0.332035, 0.390909, 0.340693, 0.386580, 0.336364, 0.386147, 0.332035, 0.390909, 0.336364] [0.386147, 0.336364, 0.386580, 0.336364, 0.381818, 0.331602, 0.386580, 0.331602, 0.381818, 0.336364, 0.386580, 0.336364, 0.286147, 0.436364, 0.336364, 0.386580, 0.336364, 0.386147, 0.386580, 0.340693, 0.341126, 0.390909, 0.354545, 0.340693, 0.390909, 
0.350216, 0.381818, 0.336364, 0.381818, 0.313420, 0.395671, 0.331602, 0.386580, 0.140693, 0.563636, 0.400000, 0.336364, , 0.609091, 0.395671, 0.345455, 0.395238, 0.350216, 0.400000, 0.336364, 0.386147, 0.336364, 0.336364, 0.381818, 0.332035, 0.381818, 0.349784, 0.336364, 0.390909, 0.329870, 0.385281, 0.329004, 0.341991, 0.367965, 0.380952, 0.346320, 0.324675, 0.354978, 0.307359, 0.359307, 0.320346, 0.311688, 0.354978, 0.311688, 0.341991, 0.303030, 0.350649, 0.311688, 0.346320, 0.316017, 0.354978, 0.311688, 0.354978, 0.303030, 0.354978, 0.311688, 0.367965, 0.346320, 0.316017, 0.103896, 0.584416, 0.467532, 0.329004, 0.389610, 0.238095, 0.378788, 0.378788, 0.298701, 0.419913, 0.419913, 0.333333]] predict output predict data------------------------------------ [0.295035, 0.402336, 0.266329, 0.410106, 0.398834, 0.394344, 0.367269, 0.441058, 0.368736, 0.491675, 0.310015, 0.394089, 0.360559, 0.348193, 0.432249, 0.189957, 0.239649, 0.375693, 0.374637, 0.543097, 0.349762, 0.261214, 0.296163, 0.232716, 0.400368, 0.325576, 0.414453, 0.416526, 0.281677, 0.424760, 0.263459, 0.614032, 0.391415, 0.315304, 0.327117, 0.394267, 0.269381, 0.329754, 0.429086, 0.199661, 0.422053, 0.452569, 0.246464, 0.375692, 0.324811, 0.269992, 0.353887, 0.478460, 0.457539, 0.304475, 0.350290, 0.348431, 0.442602, 0.525911, 0.419034, 0.290074, 0.249516, 0.316269, 0.311577, 0.324696, 0.327152, 0.206537, 0.424409, 0.324708, 0.196245, 0.592893, 0.427110, 0.167560, 0.249076, 0.252712, 0.256961, 0.493125, 0.292637, 0.288409, 0.379452, 0.283015, 0.407624, 0.418018, 0.465182, 0.221875, 0.237757, 0.263632, 0.468349, 0.499961, 0.432698, 0.420996, 0.360633, 0.281473, 0.257874, 0.434252, 0.312197, 0.451346, 0.326823, 0.269862, 0.453124, 0.367169, 0.379352, 0.301301, 0.197925, 0.124989, 0.264033, 0.348791, 0.367305, 0.387211, 0.522218, 0.458287, 0.237339, 0.393557, 0.294964, 0.297004, 0.337845, 0.267591, 0.286094, 0.306340, 0.358737, 0.287091, 0.359428, 0.394696, 0.258891, 0.353241, 0.291400, 0.588031, 0.246602, 
0.532488, 0.411316, 0.313089, 0.304064, 0.397415, 0.249475, 0.330778, 0.407016, 0.378881, 0.443070, 0.361640, 0.412381, 0.314457, 0.347157, 0.511132, 0.280886, 0.349706, 0.324805, 0.280468, 0.410394, 0.305351, 0.216179, 0.296858, 0.484862, 0.473575, 0.374156, 0.289820, 0.227066, 0.263143, 0.520773, 0.369732, 0.388833, 0.363224, 0.400892, 0.351034, 0.358700, 0.345579, 0.307472, 0.251734, 0.563745, 0.603389, 0.424108, 0.001550, 0.371443, 0.380434, 0.355304, 0.370814, 0.355286, 0.394746, 0.338117, 0.302394, 0.180776, 0.400479, 0.413734, 0.259098, 0.219308, 0.274847, 0.392848, 0.166929, 0.250864, 0.269982, 0.394170, 0.296563, 0.417997, 0.348385, 0.301364, 0.321909, 0.365151, 0.392425] single_predict inputdata------------------------------------ [[ , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , 0.235065, 0.235065, 0.390909, 0.390909, 0.441126, 0.395238, 0.441126, 0.386147, 0.450216, 0.386147, 0.409091, 0.422944, 0.418182, 0.436364, 0.390909, 0.336364, 0.458874, 0.463636, 0.381818, 0.436364, 0.381818, 0.390909, 0.436364, 0.409091, 0.381818, 0.441126, 0.431602, 0.352381, 0.352381, 0.427273, 0.431602, 0.404762, 0.327273, 0.540693, 0.472727, 0.409091, 0.472727, 0.409091, 0.468398, 0.413420, 0.463636, 0.409091, 0.468398, 0.409091, 0.463636, 0.404329, 0.459307, 0.404329, 0.459307, 0.404329, 0.422944, 0.467965, 0.409091, 0.477489, 0.409091, 0.418182, 0.463636, 0.413420, 0.463636, 0.413853, 0.404329, 0.404762, 0.472727] [0.463636, 0.290909, 0.422511, 0.254545, 0.763636, 0.436364, 0.368398, 0.231602, 0.695671, 0.349784, 0.550216, 0.413420, 0.350216, 0.536364, 0.400000, 0.463636, 0.363636, 0.363636, 0.518182, 0.458874, 0.413853, 0.463636, 0.409091, 0.418182, 0.467965, 0.422944, 0.400000, 0.467965, 0.418182, 0.409091, 0.409091, 0.432035, 0.463636, 0.400000, 0.472727, 0.400000, 0.400000, 0.458874, 0.427273, 0.422944, 0.463636, 0.349784, 0.209091, 0.558874, 0.558874, 0.418182, 0.400000, 0.468398, 0.109091, 0.727273, 0.413420, 0.404762, 0.458874, 0.404762, 
0.400000, 0.395238, 0.450216, 0.454545, 0.436364, 0.395238, 0.395671, 0.436364, 0.458874, 0.318182, 0.461472, 0.461472, 0.386147, 0.386147, 0.386147, 0.590909, 0.359307, 0.440693, 0.440693, 0.413420, 0.427273, 0.241126, 0.618182, 0.381818, 0.390909, 0.518182, 0.240693, 0.627273, 0.304762, 0.386147, 0.541126, 0.427273, 0.415584, 0.415584, 0.359307, 0.531602, 0.354545, 0.438528, 0.438528, 0.418182, 0.449784, 0.390909] [0.413853, 0.336364, 0.531602, 0.409091, 0.345455, 0.461472, 0.461472, 0.463636, 0.345455, 0.513420, 0.400000, 0.459307, 0.286147, 0.447619, 0.447619, 0.531602, 0.400000, 0.450216, 0.418182, 0.340693, 0.495671, 0.463636, 0.404329, 0.454545, 0.400000, 0.354545, 0.470563, 0.470563, 0.409091, 0.463636, 0.372727, 0.372727, 0.400000, 0.463636, 0.413420, 0.404762, 0.463636, 0.400000, 0.463636, 0.413420, 0.463636, 0.404762, 0.467965, 0.409091, 0.468398, 0.413420, 0.418182, 0.463636, 0.404762, 0.463636, 0.427273, 0.409091, 0.467965, 0.404762, 0.463636, 0.409091, 0.463636, 0.404329, 0.463636, 0.418182, 0.468398, 0.413420, 0.468398, 0.404329, 0.463636, 0.409091, 0.463636, 0.422944, 0.354545, 0.490909, 0.304329, 0.495671, 0.456710, 0.456710, 0.481818, 0.404762, 0.427273, 0.113420, 0.786580, 0.404329, 0.477489, 0.404329, 0.413853, 0.463636, 0.404329, 0.463636, 0.404762, 0.463636, 0.409091, 0.463636, 0.404329, 0.463636, 0.404762, 0.463636, 0.404329, 0.468398] [0.409091, 0.422511, 0.459307, 0.409091, 0.458874, 0.409091, 0.459307, 0.404329, 0.463636, 0.404762, 0.458874, 0.404762, 0.467965, 0.404762, 0.463636, 0.409091, 0.467965, 0.409091, 0.472727, 0.409091, 0.232035, 0.649784, 0.413853, 0.463636, 0.409091, 0.467965, 0.409091, 0.463636, 0.409091, 0.468398, 0.409091, 0.467965, 0.404762, 0.472727, 0.404329, 0.481818, 0.404762, 0.295238, 0.536364, 0.477489, 0.236364, 0.649784, 0.422944, 0.472727, 0.349784, 0.490909, 0.463636, 0.413853, 0.477056, 0.404762, 0.404329, 0.404762, 0.463636, 0.467965, 0.400000, 0.463636, 0.404762, 0.463636, 0.404329, 0.459307, 0.404329, 
0.459307, 0.404329, 0.459307, 0.340693, 0.536364, 0.400000, 0.459307, 0.400000, 0.400000, 0.467965, 0.390909, 0.450216, 0.386147, 0.441126, 0.386147, 0.441126, 0.390909, 0.449784, 0.395671, 0.445455, 0.395238, 0.450216, 0.395238, 0.450216, 0.400000, 0.454545, 0.400000, 0.449784, 0.400000, 0.454545, 0.404762, 0.467965, 0.404762, 0.413420, 0.463636] [0.400000, 0.395671, 0.454545, 0.400000, 0.454545, 0.395238, 0.450216, 0.400000, 0.458874, 0.400000, 0.454545, 0.400000, 0.454545, 0.395671, 0.454545, 0.400000, 0.454545, 0.395238, 0.459307, 0.395238, 0.459307, 0.404329, 0.454545, 0.345455, 0.513853, 0.395238, 0.459307, 0.400000, 0.454545, 0.413420, 0.404762, 0.458874, 0.345455, 0.467965, 0.467965, 0.432035, 0.231602, 0.590909, 0.395671, 0.440693, 0.400000, 0.272727, 0.418182, 0.559307, 0.386147, 0.272727, 0.527273, 0.381818, 0.438528, 0.438528, 0.404329, 0.386580, 0.436364, 0.377056, 0.432035, 0.386147, 0.432035, 0.377056, 0.436364, 0.381818, 0.441126, 0.381818, 0.436364, 0.381818, 0.436364, 0.377056, 0.436364, 0.377489, 0.431602, 0.381818, 0.432035, 0.377056, 0.432035, 0.381818, 0.427273, 0.381818, 0.427273, 0.377056, 0.432035, 0.377056, 0.432035, 0.377056, 0.432035, 0.395238, 0.322944, 0.286147, 0.254545, 0.732035, 0.309091, 0.563636, 0.386147, 0.190909, 0.432035, 0.500000, 0.500000, 0.404329] [0.436364, 0.404762, 0.440693, 0.413853, 0.390909, 0.200000, 0.451515, 0.451515, 0.451515, 0.451515, 0.451515, 0.436364, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.445455, 0.386147, 0.441126, 0.390909, 0.440693, 0.390909, 0.441126, 0.395238, 0.441126, 0.386147, 0.445455, 0.386580, 0.418182, 0.431602, 0.418182, 0.395671, 0.458874, 0.395671, 0.449784, 0.400000, 0.450216, 0.400000, 0.404329, 0.472727, 0.395671, 0.458874, 0.395671, 0.404329, 0.454545, 0.390909, 0.445455, 0.395671, 0.445455, 0.395238, 0.454545, 0.390909, 0.450216, 0.390909, 0.449784, 0.395671, 0.449784, 0.390909, 0.395671, 0.463636, 0.400000, 0.427273, 0.449784, 
0.427273, 0.409091, 0.277489, 0.545455, 0.445455, 0.390909, 0.177056, 0.672727, 0.336364, 0.481818, 0.395671, 0.454545, 0.400000, 0.454545, 0.454545, 0.395238, 0.390909, 0.450216, 0.395238, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216] [0.390909, 0.449784, 0.390909, 0.341126, 0.518182, 0.390909, 0.449784, 0.390909, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216, 0.395238, 0.450216, 0.395238, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216, 0.390909, 0.449784, 0.400000, 0.450216, 0.381818, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.449784, 0.390909, 0.445455, 0.390909, 0.445455, 0.390909, 0.450216, 0.390909, 0.449784, 0.390909, 0.445455, 0.395671, 0.440693, 0.386580, 0.440693, 0.381818, 0.445455, 0.386580, 0.440693, 0.386580, 0.440693, 0.390909, 0.441126, 0.381818, 0.440693, 0.381818, 0.441126, 0.386147, 0.395671, 0.413420, 0.450216, 0.386147, 0.441126, 0.386147, 0.445455, 0.390909, 0.441126, 0.409091, 0.386147, 0.441126, 0.386147, 0.445455, 0.390909, 0.441126, 0.381818, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.218182, 0.497835, 0.497835, 0.381818, 0.436364, 0.381818, 0.431602, 0.381818, 0.441126, 0.381818, 0.427273, 0.386147, 0.163636, ] [1.000000, 0.432035, 0.427273, 0.377056, 0.427273, 0.368398, 0.418182, 0.386147, 0.368398, 0.418182, 0.363636, 0.418182, 0.363636, 0.418182, 0.363636, 0.418182, 0.363636, 0.413420, 0.363636, 0.418182, 0.363636, 0.413853, 0.345455, 0.413420, 0.363636, 0.409091, 0.359307, 0.409091, 0.354545, 0.404329, 0.354545, 0.409091, 0.354545, 0.409091, 0.336364, 0.377489, 0.327273, 0.377056, 0.327273, 0.286580, 0.431602, 0.332035, 0.381818, 0.331602, 0.372727, 0.313853, 0.358874, 0.309091, 0.350216, 0.309091, 0.349784, 0.309091, 0.350216, 0.304329, 0.359307, 0.304329, 0.354545, 0.304762, 0.349784, 0.313853, 0.354545, 0.329437, 0.329437, 0.313853, 0.358874, 0.313853, 0.358874, 0.313853, 0.354545, 0.309091, 0.358874, 0.309091, 0.386580, 0.331602, 0.377489, 0.331602, 
0.381818, 0.332035, 0.377056, 0.332035, 0.381818, 0.331602, 0.381818, 0.332035, 0.381818, 0.331602, 0.381818, 0.332035, 0.377056, 0.332035, 0.281818, 0.436364, 0.331602, 0.381818, 0.332035, 0.377056] [0.332035, 0.377056, 0.345455, 0.381818, 0.332035, 0.381818, 0.331602, 0.377489, 0.331602, 0.377489, 0.331602, 0.377489, 0.327273, 0.377056, 0.332035, 0.377056, 0.332035, 0.377056, 0.332035, 0.377056, 0.336364, 0.327273, 0.390909, 0.332035, 0.377056, 0.327273, 0.377489, 0.327273, 0.377056, 0.327273, 0.381818, 0.332035, 0.381818, 0.331602, 0.381818, 0.336364, 0.359307, 0.309091, 0.218182, 0.436364, 0.263636, 0.367965, 0.367965, 0.354545, 0.318182, 0.304329, 0.359307, 0.318182, 0.304329, 0.350216, 0.304329, 0.354545, 0.318182, 0.304762, 0.367965, 0.304762, 0.345455, 0.309091, 0.313420, 0.354545, 0.309091, 0.350216, 0.309091, 0.354545, 0.309091, 0.349784, 0.313853, 0.354545, 0.267965, , , , , , , , , , , , 0.335931, 0.341126, 0.386147, 0.336364, 0.381818, 0.336364, 0.381818, 0.332035, 0.390909, 0.340693, 0.386580, 0.336364, 0.386147, 0.332035, 0.390909, 0.336364] [0.386147, 0.336364, 0.386580, 0.336364, 0.381818, 0.331602, 0.386580, 0.331602, 0.381818, 0.336364, 0.386580, 0.336364, 0.286147, 0.436364, 0.336364, 0.386580, 0.336364, 0.386147, 0.386580, 0.340693, 0.341126, 0.390909, 0.354545, 0.340693, 0.390909, 0.350216, 0.381818, 0.336364, 0.381818, 0.313420, 0.395671, 0.331602, 0.386580, 0.140693, 0.563636, 0.400000, 0.336364, , 0.609091, 0.395671, 0.345455, 0.395238, 0.350216, 0.400000, 0.336364, 0.386147, 0.336364, 0.336364, 0.381818, 0.332035, 0.381818, 0.349784, 0.336364, 0.390909, 0.329870, 0.385281, 0.329004, 0.341991, 0.367965, 0.380952, 0.346320, 0.324675, 0.354978, 0.307359, 0.359307, 0.320346, 0.311688, 0.354978, 0.311688, 0.341991, 0.303030, 0.350649, 0.311688, 0.346320, 0.316017, 0.354978, 0.311688, 0.354978, 0.303030, 0.354978, 0.311688, 0.367965, 0.346320, 0.316017, 0.103896, 0.584416, 0.467532, 0.329004, 0.389610, 0.238095, 0.378788, 0.378788, 0.298701, 
0.419913, 0.419913, 0.333333]] single_predict output predict data------------------------------------ [0.389802, 0.597915, 0.905583, 0.311131, 0.495656, 0.689280, 0.153522, 0.717448, 0.632593, 0.425175, 0.107704, 0.656337, 0.254818, 0.614943, 0.501162, 0.043916, 0.560118, 0.478210, 0.206129, 0.262765, 0.262571, 0.175422, 0.108688, 0.679870, 0.846185, 0.457608, 0.190604, 0.572521, 0.116944, 0.730495, 0.432012, 0.958689, 0.231509, 0.391596, 0.073421, 0.543554, 0.607533, 0.963577, 0.488422, 0.609506, 0.085435, 0.208220, 0.444951, 0.086885, 0.609601, 0.281158, 0.545926, 0.439753, 0.409375, 0.546458, 0.759664, 0.763991, 0.426631, 0.160127, 0.520508, 0.308276, 0.542363, 0.761224, 0.798679, 0.868364, 0.333578, 0.548297, 0.682175, 0.587619, 0.125105, 0.324610, 0.902170, 0.133932, 0.083708, 0.449250, 0.230032, 0.349116, 0.099537, 0.812767, 0.730489, 0.174650, 0.315293, 0.141455, 0.250702, 0.164722, 0.654409, 0.187149, 0.248818, 0.240006, 0.636383, 0.812820, 0.508121, 0.607807, 0.346713, 0.194062, 0.199359, 0.276878, 0.522631, 0.405292, 0.290740, 0.129850, 0.384727, 0.296280, 0.742341, 0.547241, 0.414050, 0.303259, 0.765255, 0.385090, 0.218101, 0.639672, 0.422615, 0.244669, 0.188252, 0.185031, 0.252912, 0.824582, 0.658651, 0.762592, 0.588420, 0.312481, 0.863796, 0.054983, 0.437385, 0.637060, 0.501782, 0.815669, 0.550688, 0.752877, 0.613660, 0.751275, 0.396306, 0.373400, 0.543618, 0.389181, 0.506557, 0.304902, 0.729749, 0.060062, 0.389892, 0.531745, 0.527831, 0.456572, 0.112960, 0.331922, 0.217177, 0.836067, 0.102013, 0.230309, 0.758432, 0.364875, 0.196932, 0.930375, 0.298557, 0.088905, 0.533363, 0.189101, 0.113858, 0.455451, 0.316316, 0.385127, 0.097742, 0.452967, 0.732501, 0.306971, 0.857547, 0.650078, 0.132360, 0.883436, 0.536425, 0.002375, 0.487041, 0.553193, 0.491384, 0.175722, 0.110059, 0.281420, 0.239131, 0.391724, 0.413912, 0.403337, 0.552353, 0.594573, 0.700203, 0.701609, 0.149008, 0.123142, 0.223642, 0.155526, 0.529709, 0.423231, 0.403418, 0.858137, 0.504561, 
0.146334, 0.401272, 0.821282]

xinsuinizhuan commented 4 years ago

I feel that for batched input, forward_propagation and predict do the same thing, and forward_propagation works best, so predict seems unnecessary. But single_predict is the most commonly used function, because in practice we predict with one image or one record at a time. So let's concentrate on the single_predict function.

josephjaspers commented 4 years ago

Hi, I just added:

    /** Copies a single cell-state from the forward_propagation cell-state data to the single predict cell-state.
     *
     * Any existing data in the single_predict cell-state is overwritten.
     * This function is only relevant to the recurrent-neural networks.
     */
    void move_training_data_to_single_predict(int batch_index) {
        m_layer_chain.for_each([&](auto& layer) {
            layer.move_training_data_to_single_predict(batch_index);
        });
    }

To the Neural_Network class (commit https://github.com/josephjaspers/blackcat_tensors/commit/05a9daf7d3483024872346c39d41f72b0627d418). However, even when using move_training_data_to_single_predict, the results of single_predict and forward_propagation are still not exactly the same, although they are much closer.
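
To make the idea behind move_training_data_to_single_predict concrete, here is a toy sketch. It is not BlackCat's LSTM: it assumes a one-unit recurrent cell h_t = tanh(w * x_t + u * h_{t-1}) with one state value per batch column, and ToyRecurrentLayer is a hypothetical name (only the method name mirrors the one above). It shows that once one column of the batched state is copied into the single-predict state, stepping both with the same input gives the same output, while starting from a fresh state does not.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical toy recurrent layer (NOT BlackCat's LSTM):
// h_t = tanh(w * x_t + u * h_{t-1}), one hidden value per batch column.
struct ToyRecurrentLayer {
    double w = 0.5, u = 0.9;
    std::vector<double> batch_state;  // state used by batched forward passes
    double single_state = 0.0;        // separate state used by single_predict

    explicit ToyRecurrentLayer(std::size_t batch_size) : batch_state(batch_size, 0.0) {}

    // Batched step: advances the state of every batch column.
    std::vector<double> forward(const std::vector<double>& x) {
        for (std::size_t b = 0; b < x.size(); ++b)
            batch_state[b] = std::tanh(w * x[b] + u * batch_state[b]);
        return batch_state;
    }

    // Single-sample step: advances only single_state.
    double single_predict(double x) {
        single_state = std::tanh(w * x + u * single_state);
        return single_state;
    }

    // Overwrites single_state with the state of one batch column.
    void move_training_data_to_single_predict(std::size_t batch_index) {
        assert(batch_index < batch_state.size());
        single_state = batch_state[batch_index];
    }
};

// After copying column 1's state, single_predict reproduces column 1's next output.
bool single_matches_batch() {
    ToyRecurrentLayer layer(3);
    layer.forward({0.1, 0.2, 0.3});                 // build up some state
    layer.move_training_data_to_single_predict(1);  // copy column 1's state
    std::vector<double> next = layer.forward({0.4, 0.5, 0.6});
    double single = layer.single_predict(0.5);      // same input as column 1
    return std::fabs(single - next[1]) < 1e-12;
}

// Without the copy, single_predict starts from a zero state and diverges.
bool fresh_state_differs() {
    ToyRecurrentLayer layer(3);
    layer.forward({0.1, 0.2, 0.3});
    std::vector<double> next = layer.forward({0.4, 0.5, 0.6});
    double single = layer.single_predict(0.5);
    return std::fabs(single - next[1]) > 1e-3;
}
```

In a real stacked LSTM the same copy would have to happen for every layer's cell state and hidden state, which is what the for_each over m_layer_chain above does.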

I am still debugging why there is a slight difference between the two.
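
One possible (unconfirmed) source of a small, stable mismatch: a batched forward pass and a single-sample pass may accumulate the same dot products in a different order (for example, a matrix-matrix kernel versus a matrix-vector kernel), and floating-point addition is not associative. Whether BlackCat's two code paths actually differ this way is an assumption, not something established in this thread. The minimal example below only demonstrates the underlying effect with three floats:

```cpp
#include <cassert>

// Adds the same three floats with two different groupings.
// float has a 24-bit significand, so near 1e8 the spacing between
// representable values is 8, and 1e8f + 1.0f rounds back to 1e8f.
float sum_grouped_ab_first(float a, float b, float c) {
    float ab = a + b;  // with a = 1e8f, b = 1.0f: the 1.0f is lost to rounding
    return ab + c;     // 1e8f + (-1e8f) -> 0.0f
}

float sum_grouped_ac_first(float a, float b, float c) {
    float ac = a + c;  // with a = 1e8f, c = -1e8f: exactly 0.0f
    return ac + b;     // 0.0f + 1.0f -> 1.0f
}

// Mathematically both sums equal 1, but in float arithmetic they differ.
void demo_non_associativity() {
    assert(sum_grouped_ab_first(1e8f, 1.0f, -1e8f) == 0.0f);
    assert(sum_grouped_ac_first(1e8f, 1.0f, -1e8f) == 1.0f);
}
```

In an LSTM, rounding differences like this are fed back through the cell state at every time step, so they can grow over a long sequence. A mismatch that stays near the rounding level would suggest summation order; a large mismatch would suggest the two paths are starting from different state.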

xinsuinizhuan commented 4 years ago

It crashes! When I use this network:

    auto make_lstm_network() {
        return BC::nn::neuralnetwork(
            BC::nn::lstm(BC::host_tag(), 96 * 10, 2048),
            BC::nn::lstm(BC::host_tag(), 2048, 1024),
            BC::nn::lstm(BC::host_tag(), 1024, 512),
            BC::nn::lstm(BC::host_tag(), 512, 216),
            BC::nn::feedforward(BC::host_tag(), 216, 192),
            BC::nn::logistic(BC::host_tag(), 192),
            BC::nn::logging_output_layer(BC::host_tag(), 192, BC::nn::RMSE).skip_every(100)
        );
    }
    using network_type = decltype(make_lstm_network());

the program crashes when it reaches predict.

But this network structure works fine:

    auto make_lstm_network() {
        return BC::nn::neuralnetwork(
            BC::nn::lstm(BC::host_tag(), 96 * 10, 1024),
            BC::nn::lstm(BC::host_tag(), 1024, 512),
            BC::nn::lstm(BC::host_tag(), 512, 216),
            BC::nn::feedforward(BC::host_tag(), 216, 192),
            BC::nn::logistic(BC::host_tag(), 192),
            BC::nn::logging_output_layer(BC::host_tag(), 192, BC::nn::RMSE).skip_every(100)
        );
    }
    using network_type = decltype(make_lstm_network());

I don't know why.

josephjaspers commented 4 years ago

Checking now.

Are you testing in debug mode or release mode?

xinsuinizhuan commented 4 years ago

I have only tested in release mode. Do I need to test in debug mode?

josephjaspers commented 4 years ago

No, you don't need to (that would take very long too!). I was just checking so I could replicate it.

josephjaspers commented 4 years ago

Did it crash on the first call to predict? (Or were there multiple calls to predict before it crashed?)

(Are you using malloc to create the neural network?)

xinsuinizhuan commented 4 years ago

Yes, I create the neural network dynamically. With the first network structure it crashes every time in release mode. I am testing in debug mode now.

xinsuinizhuan commented 4 years ago

My example:

//start train
LstmPredictTask* lstmpredicttask = new LstmPredictTask();
if (lstmpredicttask == NULL) {
    return -2;
}

//LstmPredictTask lstmpredicttask;
std::cout << "Neural Network architecture: \n" << lstmpredicttask->m_pnetwork.get_string_architecture() << std::endl;
lstmpredicttask->m_pnetwork.set_learning_rate(lstmpredicttask->m_learning_rate);
lstmpredicttask->m_pnetwork.set_batch_size(lstmpredicttask->m_batch_size);

int training_sets;
std::pair<cube, cube> data = load_train_data(system_tag, datafilepath, lstmpredicttask, &training_sets);
cube& inputs = data.first;
cube& outputs = data.second;

std::cout <<" training..." << std::endl;
auto start = std::chrono::system_clock::now();

std::cout << "imagesinput real data:------------------------------------" << std::endl;
auto imagesinput = reshape(inputs[0], BC::shape(96, 10, lstmpredicttask->m_batch_size));
imagesinput[0].t().print_sparse();

std::cout << "imagesoutput real data:------------------------------------" << std::endl;
auto imagesoutput = reshape(outputs[0], BC::shape(96, 2, lstmpredicttask->m_batch_size));
imagesoutput[0].t().print_sparse();

for (int i = 0; i < epochs; ++i) {
    std::cout << " current epoch: " << i << std::endl;
    for (int j = 0; j < training_sets; ++j) {
        lstmpredicttask->m_pnetwork.forward_propagation(inputs[j]);
        lstmpredicttask->m_pnetwork.back_propagation(outputs[j]);
        lstmpredicttask->m_pnetwork.update_weights();
    }
}

//if (strlen(_trainparamsavefile) != 0)
//{
//  lstmpredicttask->m_pnetwork.save(_trainparamsavefile); //Uncomment to add saving/loading
//}

auto end = std::chrono::system_clock::now();
clock total = clock(end - start);
std::cout << " training time: " << total.count() << std::endl;

{
    auto batch = inputs[0];
    mat hyps = lstmpredicttask->m_pnetwork.forward_propagation(batch);
    std::cout << " forward_propagation MAPE loss: " << BC::Scalar<double>(BC::nn::MAPE(hyps, outputs[0]) / lstmpredicttask->m_batch_size).data()[0] << std::endl;

    std::cout << "forward_propagation inputdata------------------------------------" << std::endl;
    auto imagesinput0 = reshape(inputs[0], BC::shape(96, 10, lstmpredicttask->m_batch_size));
    imagesinput0[0].t().print_sparse();

    std::cout << "forward_propagation output predict data------------------------------------" << std::endl;
    hyps[0].print();
}

{
    auto batch = inputs[0];
    mat hyps = lstmpredicttask->m_pnetwork.predict(batch);
    std::cout << " predict MAPE loss: " << BC::Scalar<double>(BC::nn::MAPE(hyps, outputs[0]) / lstmpredicttask->m_batch_size).data()[0] << std::endl;

    std::cout << "predict inputdata------------------------------------" << std::endl;
    auto imagesinput0 = reshape(inputs[0], BC::shape(96, 10, lstmpredicttask->m_batch_size));
    imagesinput0[0].t().print_sparse();

    std::cout << "predict output predict data------------------------------------" << std::endl;
    hyps[0].print();
}
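
To put numbers on "slightly different but very close", a small helper like the following could compare two output vectors element by element (for example, hyps[0] from forward_propagation against the single_predict output for the same record). It is a hypothetical sketch: compare_outputs and OutputDiff are not BlackCat functions, and it assumes both outputs have already been copied into std::vector<double>; how to do that copy depends on the BlackCat tensor API and is not shown.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical summary of how far two prediction vectors are apart.
struct OutputDiff {
    double max_abs;   // largest |a[i] - b[i]|
    double mean_abs;  // average |a[i] - b[i]|
};

// Compares two equally sized, non-empty output vectors element by element.
OutputDiff compare_outputs(const std::vector<double>& a, const std::vector<double>& b) {
    assert(a.size() == b.size() && !a.empty());
    double max_abs = 0.0, sum_abs = 0.0;
    for (std::size_t i = 0; i < a.size(); ++i) {
        double d = std::fabs(a[i] - b[i]);
        max_abs = std::max(max_abs, d);
        sum_abs += d;
    }
    return {max_abs, sum_abs / static_cast<double>(a.size())};
}
```

Reporting both the maximum and the mean makes it easier to tell a few outliers apart from a uniform drift across all 192 outputs.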
josephjaspers commented 4 years ago

I was able to reproduce the error. I will try to fix it soon.

xinsuinizhuan commented 4 years ago

OK. How about this compile error? https://github.com/josephjaspers/blackcat_tensors/issues/40#issuecomment-547206958

josephjaspers commented 4 years ago

I added a fix for the predict bug (commit https://github.com/josephjaspers/blackcat_tensors/commit/2405b820fb68a4c901c4dfdcb6b7b7da97ba2667). The issue seems to occur only with Visual Studio (not on Linux). However, on Windows predict is now identical to forward_propagation, so there isn't much reason to use it.

xinsuinizhuan commented 4 years ago

Yes, forward_propagation and predict now behave the same. We should focus on the single_predict function.

xinsuinizhuan commented 4 years ago

How is it going? There have been no updates for four days. Are you busy with work?

josephjaspers commented 4 years ago

Oh so sorry!

I thought I had responded to you.

I have been working on this branch: https://github.com/josephjaspers/blackcat_tensors/tree/add_max_pooling (On my local branch I have a working implementation of MaxPooling; I just need to add the backward implementation and wrap it in a neural network layer.)
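
For reference, the forward and backward passes of max pooling can be sketched as follows. This is a generic single-channel version with no padding and a stride equal to the window size, written with plain std::vector; it is not BlackCat's MaxPooling implementation, and the MaxPool2D name and row-major layout are illustrative assumptions. The backward pass routes each output gradient only to the input element that produced the maximum, which is why the forward pass records those indices.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Sketch of 2D max pooling on a single-channel, row-major image,
// square window, stride == window, no padding (illustrative only).
struct MaxPool2D {
    std::size_t rows, cols, window;
    std::vector<std::size_t> argmax;  // input index that won each output cell

    std::size_t out_rows() const { return rows / window; }
    std::size_t out_cols() const { return cols / window; }

    std::vector<double> forward(const std::vector<double>& in) {
        assert(in.size() == rows * cols);
        std::vector<double> out(out_rows() * out_cols());
        argmax.assign(out.size(), 0);
        for (std::size_t r = 0; r < out_rows(); ++r)
            for (std::size_t c = 0; c < out_cols(); ++c) {
                double best = -std::numeric_limits<double>::infinity();
                std::size_t best_idx = 0;
                for (std::size_t i = 0; i < window; ++i)
                    for (std::size_t j = 0; j < window; ++j) {
                        std::size_t idx = (r * window + i) * cols + (c * window + j);
                        if (in[idx] > best) { best = in[idx]; best_idx = idx; }
                    }
                out[r * out_cols() + c] = best;
                argmax[r * out_cols() + c] = best_idx;  // remembered for backward
            }
        return out;
    }

    // Each output gradient flows only to the input that won the max; all others get 0.
    std::vector<double> backward(const std::vector<double>& grad_out) const {
        std::vector<double> grad_in(rows * cols, 0.0);
        for (std::size_t k = 0; k < grad_out.size(); ++k)
            grad_in[argmax[k]] += grad_out[k];
        return grad_in;
    }
};
```

A GPU version would typically follow the same structure, with one thread per output cell in the forward pass and a scatter of gradients through the stored indices in the backward pass.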

josephjaspers commented 4 years ago

Hi!

I just added max_pooling: https://github.com/josephjaspers/blackcat_tensors/issues/49

Next steps:

- Add GPU support for convolution
- Add GPU support for max_pooling
- Add Attention-LSTM
- Add optimizers

Is single_predict still giving different results? (For me the results are slightly different but very close; I am unsure what is causing this.) If this is still an issue for you, I will prioritize it.

xinsuinizhuan commented 4 years ago

Yes, it is still a problem. Perhaps your data is small, but with my data a small difference leads to a big difference in the output. Here are the inputs and outputs of the three functions; the single_predict output differs a lot:

    Neural Network architecture:
    LSTM: inputs: 960 outputs: 1024
    LSTM: inputs: 1024 outputs: 512
    LSTM: inputs: 512 outputs: 216
    FeedForward: inputs: 216 outputs: 192
    Logistic: inputs: 192 outputs: 192
    Output_Layer: inputs: 192 outputs: 192

training... imagesinput real data:------------------------------------ [[ , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , 0.235065, 0.235065, 0.390909, 0.390909, 0.441126, 0.395238, 0.441126, 0.386147, 0.450216, 0.386147, 0.409091, 0.422944, 0.418182, 0.436364, 0.390909, 0.336364, 0.458874, 0.463636, 0.381818, 0.436364, 0.381818, 0.390909, 0.436364, 0.409091, 0.381818, 0.441126, 0.431602, 0.352381, 0.352381, 0.427273, 0.431602, 0.404762, 0.327273, 0.540693, 0.472727, 0.409091, 0.472727, 0.409091, 0.468398, 0.413420, 0.463636, 0.409091, 0.468398, 0.409091, 0.463636, 0.404329, 0.459307, 0.404329, 0.459307, 0.404329, 0.422944, 0.467965, 0.409091, 0.477489, 0.409091, 0.418182, 0.463636, 0.413420, 0.463636, 0.413853, 0.404329, 0.404762, 0.472727] [0.463636, 0.290909, 0.422511, 0.254545, 0.763636, 0.436364, 0.368398, 0.231602, 0.695671, 0.349784, 0.550216, 0.413420, 0.350216, 0.536364, 0.400000, 0.463636, 0.363636, 0.363636, 0.518182, 0.458874, 0.413853, 0.463636, 0.409091, 0.418182, 0.467965, 0.422944, 0.400000, 0.467965, 0.418182, 0.409091, 0.409091, 0.432035, 0.463636, 0.400000, 0.472727, 0.400000, 0.400000, 0.458874, 0.427273, 0.422944, 0.463636, 0.349784, 0.209091, 0.558874, 0.558874, 0.418182, 0.400000, 0.468398, 0.109091, 0.727273, 0.413420, 0.404762, 0.458874, 0.404762, 0.400000, 0.395238, 0.450216, 0.454545, 0.436364, 0.395238, 0.395671, 0.436364, 0.458874, 0.318182, 0.461472, 0.461472, 0.386147, 0.386147, 0.386147, 0.590909, 0.359307, 0.440693, 0.440693, 0.413420, 0.427273, 0.241126, 0.618182, 0.381818, 0.390909, 0.518182, 0.240693, 0.627273, 0.304762, 0.386147, 0.541126, 0.427273, 0.415584, 0.415584, 0.359307, 0.531602, 0.354545, 0.438528, 0.438528, 0.418182, 0.449784, 0.390909] [0.413853, 0.336364, 0.531602, 0.409091, 0.345455, 0.461472, 0.461472, 0.463636, 0.345455, 0.513420, 0.400000, 0.459307, 0.286147, 0.447619, 0.447619, 0.531602, 0.400000, 0.450216, 0.418182, 0.340693, 0.495671, 0.463636, 0.404329, 0.454545, 0.400000, 0.354545, 
0.470563, 0.470563, 0.409091, 0.463636, 0.372727, 0.372727, 0.400000, 0.463636, 0.413420, 0.404762, 0.463636, 0.400000, 0.463636, 0.413420, 0.463636, 0.404762, 0.467965, 0.409091, 0.468398, 0.413420, 0.418182, 0.463636, 0.404762, 0.463636, 0.427273, 0.409091, 0.467965, 0.404762, 0.463636, 0.409091, 0.463636, 0.404329, 0.463636, 0.418182, 0.468398, 0.413420, 0.468398, 0.404329, 0.463636, 0.409091, 0.463636, 0.422944, 0.354545, 0.490909, 0.304329, 0.495671, 0.456710, 0.456710, 0.481818, 0.404762, 0.427273, 0.113420, 0.786580, 0.404329, 0.477489, 0.404329, 0.413853, 0.463636, 0.404329, 0.463636, 0.404762, 0.463636, 0.409091, 0.463636, 0.404329, 0.463636, 0.404762, 0.463636, 0.404329, 0.468398] [0.409091, 0.422511, 0.459307, 0.409091, 0.458874, 0.409091, 0.459307, 0.404329, 0.463636, 0.404762, 0.458874, 0.404762, 0.467965, 0.404762, 0.463636, 0.409091, 0.467965, 0.409091, 0.472727, 0.409091, 0.232035, 0.649784, 0.413853, 0.463636, 0.409091, 0.467965, 0.409091, 0.463636, 0.409091, 0.468398, 0.409091, 0.467965, 0.404762, 0.472727, 0.404329, 0.481818, 0.404762, 0.295238, 0.536364, 0.477489, 0.236364, 0.649784, 0.422944, 0.472727, 0.349784, 0.490909, 0.463636, 0.413853, 0.477056, 0.404762, 0.404329, 0.404762, 0.463636, 0.467965, 0.400000, 0.463636, 0.404762, 0.463636, 0.404329, 0.459307, 0.404329, 0.459307, 0.404329, 0.459307, 0.340693, 0.536364, 0.400000, 0.459307, 0.400000, 0.400000, 0.467965, 0.390909, 0.450216, 0.386147, 0.441126, 0.386147, 0.441126, 0.390909, 0.449784, 0.395671, 0.445455, 0.395238, 0.450216, 0.395238, 0.450216, 0.400000, 0.454545, 0.400000, 0.449784, 0.400000, 0.454545, 0.404762, 0.467965, 0.404762, 0.413420, 0.463636] [0.400000, 0.395671, 0.454545, 0.400000, 0.454545, 0.395238, 0.450216, 0.400000, 0.458874, 0.400000, 0.454545, 0.400000, 0.454545, 0.395671, 0.454545, 0.400000, 0.454545, 0.395238, 0.459307, 0.395238, 0.459307, 0.404329, 0.454545, 0.345455, 0.513853, 0.395238, 0.459307, 0.400000, 0.454545, 0.413420, 0.404762, 0.458874, 0.345455, 
0.467965, 0.467965, 0.432035, 0.231602, 0.590909, 0.395671, 0.440693, 0.400000, 0.272727, 0.418182, 0.559307, 0.386147, 0.272727, 0.527273, 0.381818, 0.438528, 0.438528, 0.404329, 0.386580, 0.436364, 0.377056, 0.432035, 0.386147, 0.432035, 0.377056, 0.436364, 0.381818, 0.441126, 0.381818, 0.436364, 0.381818, 0.436364, 0.377056, 0.436364, 0.377489, 0.431602, 0.381818, 0.432035, 0.377056, 0.432035, 0.381818, 0.427273, 0.381818, 0.427273, 0.377056, 0.432035, 0.377056, 0.432035, 0.377056, 0.432035, 0.395238, 0.322944, 0.286147, 0.254545, 0.732035, 0.309091, 0.563636, 0.386147, 0.190909, 0.432035, 0.500000, 0.500000, 0.404329] [0.436364, 0.404762, 0.440693, 0.413853, 0.390909, 0.200000, 0.451515, 0.451515, 0.451515, 0.451515, 0.451515, 0.436364, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.445455, 0.386147, 0.441126, 0.390909, 0.440693, 0.390909, 0.441126, 0.395238, 0.441126, 0.386147, 0.445455, 0.386580, 0.418182, 0.431602, 0.418182, 0.395671, 0.458874, 0.395671, 0.449784, 0.400000, 0.450216, 0.400000, 0.404329, 0.472727, 0.395671, 0.458874, 0.395671, 0.404329, 0.454545, 0.390909, 0.445455, 0.395671, 0.445455, 0.395238, 0.454545, 0.390909, 0.450216, 0.390909, 0.449784, 0.395671, 0.449784, 0.390909, 0.395671, 0.463636, 0.400000, 0.427273, 0.449784, 0.427273, 0.409091, 0.277489, 0.545455, 0.445455, 0.390909, 0.177056, 0.672727, 0.336364, 0.481818, 0.395671, 0.454545, 0.400000, 0.454545, 0.454545, 0.395238, 0.390909, 0.450216, 0.395238, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216] [0.390909, 0.449784, 0.390909, 0.341126, 0.518182, 0.390909, 0.449784, 0.390909, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216, 0.395238, 0.450216, 0.395238, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216, 0.390909, 0.449784, 0.400000, 0.450216, 0.381818, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.449784, 0.390909, 0.445455, 0.390909, 0.445455, 0.390909, 0.450216, 0.390909, 
0.449784, 0.390909, 0.445455, 0.395671, 0.440693, 0.386580, 0.440693, 0.381818, 0.445455, 0.386580, 0.440693, 0.386580, 0.440693, 0.390909, 0.441126, 0.381818, 0.440693, 0.381818, 0.441126, 0.386147, 0.395671, 0.413420, 0.450216, 0.386147, 0.441126, 0.386147, 0.445455, 0.390909, 0.441126, 0.409091, 0.386147, 0.441126, 0.386147, 0.445455, 0.390909, 0.441126, 0.381818, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.218182, 0.497835, 0.497835, 0.381818, 0.436364, 0.381818, 0.431602, 0.381818, 0.441126, 0.381818, 0.427273, 0.386147, 0.163636, ] [1.000000, 0.432035, 0.427273, 0.377056, 0.427273, 0.368398, 0.418182, 0.386147, 0.368398, 0.418182, 0.363636, 0.418182, 0.363636, 0.418182, 0.363636, 0.418182, 0.363636, 0.413420, 0.363636, 0.418182, 0.363636, 0.413853, 0.345455, 0.413420, 0.363636, 0.409091, 0.359307, 0.409091, 0.354545, 0.404329, 0.354545, 0.409091, 0.354545, 0.409091, 0.336364, 0.377489, 0.327273, 0.377056, 0.327273, 0.286580, 0.431602, 0.332035, 0.381818, 0.331602, 0.372727, 0.313853, 0.358874, 0.309091, 0.350216, 0.309091, 0.349784, 0.309091, 0.350216, 0.304329, 0.359307, 0.304329, 0.354545, 0.304762, 0.349784, 0.313853, 0.354545, 0.329437, 0.329437, 0.313853, 0.358874, 0.313853, 0.358874, 0.313853, 0.354545, 0.309091, 0.358874, 0.309091, 0.386580, 0.331602, 0.377489, 0.331602, 0.381818, 0.332035, 0.377056, 0.332035, 0.381818, 0.331602, 0.381818, 0.332035, 0.381818, 0.331602, 0.381818, 0.332035, 0.377056, 0.332035, 0.281818, 0.436364, 0.331602, 0.381818, 0.332035, 0.377056] [0.332035, 0.377056, 0.345455, 0.381818, 0.332035, 0.381818, 0.331602, 0.377489, 0.331602, 0.377489, 0.331602, 0.377489, 0.327273, 0.377056, 0.332035, 0.377056, 0.332035, 0.377056, 0.332035, 0.377056, 0.336364, 0.327273, 0.390909, 0.332035, 0.377056, 0.327273, 0.377489, 0.327273, 0.377056, 0.327273, 0.381818, 0.332035, 0.381818, 0.331602, 0.381818, 0.336364, 0.359307, 0.309091, 0.218182, 0.436364, 0.263636, 0.367965, 0.367965, 0.354545, 0.318182, 0.304329, 0.359307, 0.318182, 
0.304329, 0.350216, 0.304329, 0.354545, 0.318182, 0.304762, 0.367965, 0.304762, 0.345455, 0.309091, 0.313420, 0.354545, 0.309091, 0.350216, 0.309091, 0.354545, 0.309091, 0.349784, 0.313853, 0.354545, 0.267965, , , , , , , , , , , , 0.335931, 0.341126, 0.386147, 0.336364, 0.381818, 0.336364, 0.381818, 0.332035, 0.390909, 0.340693, 0.386580, 0.336364, 0.386147, 0.332035, 0.390909, 0.336364] [0.386147, 0.336364, 0.386580, 0.336364, 0.381818, 0.331602, 0.386580, 0.331602, 0.381818, 0.336364, 0.386580, 0.336364, 0.286147, 0.436364, 0.336364, 0.386580, 0.336364, 0.386147, 0.386580, 0.340693, 0.341126, 0.390909, 0.354545, 0.340693, 0.390909, 0.350216, 0.381818, 0.336364, 0.381818, 0.313420, 0.395671, 0.331602, 0.386580, 0.140693, 0.563636, 0.400000, 0.336364, , 0.609091, 0.395671, 0.345455, 0.395238, 0.350216, 0.400000, 0.336364, 0.386147, 0.336364, 0.336364, 0.381818, 0.332035, 0.381818, 0.349784, 0.336364, 0.390909, 0.329870, 0.385281, 0.329004, 0.341991, 0.367965, 0.380952, 0.346320, 0.324675, 0.354978, 0.307359, 0.359307, 0.320346, 0.311688, 0.354978, 0.311688, 0.341991, 0.303030, 0.350649, 0.311688, 0.346320, 0.316017, 0.354978, 0.311688, 0.354978, 0.303030, 0.354978, 0.311688, 0.367965, 0.346320, 0.316017, 0.103896, 0.584416, 0.467532, 0.329004, 0.389610, 0.238095, 0.378788, 0.378788, 0.298701, 0.419913, 0.419913, 0.333333]] imagesoutput real data:------------------------------------ [[0.380952, 0.337662, 0.385281, 0.341991, 0.385281, 0.354978, 0.341991, 0.393939, 0.346320, 0.346320, 0.389610, 0.341991, 0.393939, 0.346320, 0.337662, 0.385281, 0.341991, 0.398268, 0.151515, 0.580087, 0.337662, 0.337662, 0.389610, 0.359307, 0.385281, 0.337662, 0.385281, 0.337662, 0.380952, 0.346320, 0.341991, 0.389610, 0.333333, 0.393939, 0.333333, 0.380952, 0.341991, 0.337662, 0.385281, 0.337662, 0.385281, 0.333333, 0.385281, 0.333333, 0.337662, 0.385281, 0.324675, 0.376623, 0.333333, 0.376623, 0.329004, 0.367965, 0.320346, 0.372294, 0.329004, 0.367965, 0.324675, 0.372294, 0.320346, 
0.372294, 0.324675, 0.367965, 0.324675, 0.367965, 0.298701, 0.307359, 0.298701, 0.341991, 0.307359, 0.350649, 0.298701, 0.350649, 0.298701, 0.354978, 0.298701, 0.346320, 0.303030, 0.350649, 0.303030, 0.346320, 0.307359, 0.346320, 0.307359, 0.346320, 0.307359, 0.380952, 0.333333, 0.376623, 0.333333, 0.380952, 0.333333, 0.376623, 0.341991, 0.376623, 0.333333, 0.380952] [0.333333, 0.380952, 0.337662, 0.337662, 0.389610, 0.333333, 0.142857, 0.601732, 0.341991, 0.350649, 0.389610, 0.341991, 0.385281, 0.346320, 0.337662, 0.385281, 0.333333, 0.393939, 0.333333, 0.393939, 0.341991, 0.350649, 0.346320, 0.406926, 0.346320, 0.333333, 0.147186, 0.580087, 0.389610, 0.333333, 0.350649, 0.380952, 0.341991, 0.393939, 0.341991, 0.389610, 0.346320, 0.389610, 0.346320, 0.385281, 0.341991, 0.393939, 0.341991, 0.389610, 0.346320, 0.380952, 0.337662, 0.385281, 0.333333, 0.380952, 0.341991, 0.333333, 0.376623, 0.376623, 0.333333, 0.376623, 0.329004, 0.341991, 0.380952, 0.385281, 0.337662, 0.337662, 0.385281, 0.337662, 0.385281, 0.337662, 0.385281, 0.341991, 0.389610, 0.341991, 0.385281, 0.333333, 0.385281, 0.337662, 0.380952, 0.337662, 0.385281, 0.341991, 0.380952, 0.337662, 0.376623, 0.303030, 0.346320, 0.307359, 0.350649, 0.307359, 0.346320, 0.303030, 0.350649, 0.329004, 0.372294, 0.303030, 0.341991, 0.303030, 0.337662, 0.303030]]

forward_propagation MAPE loss: 0.0321961 forward_propagation inputdata------------------------------------ [[ , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , 0.235065, 0.235065, 0.390909, 0.390909, 0.441126, 0.395238, 0.441126, 0.386147, 0.450216, 0.386147, 0.409091, 0.422944, 0.418182, 0.436364, 0.390909, 0.336364, 0.458874, 0.463636, 0.381818, 0.436364, 0.381818, 0.390909, 0.436364, 0.409091, 0.381818, 0.441126, 0.431602, 0.352381, 0.352381, 0.427273, 0.431602, 0.404762, 0.327273, 0.540693, 0.472727, 0.409091, 0.472727, 0.409091, 0.468398, 0.413420, 0.463636, 0.409091, 0.468398, 0.409091, 0.463636, 0.404329, 0.459307, 0.404329, 0.459307, 0.404329, 0.422944, 0.467965, 0.409091, 0.477489, 0.409091, 0.418182, 0.463636, 0.413420, 0.463636, 0.413853, 0.404329, 0.404762, 0.472727] [0.463636, 0.290909, 0.422511, 0.254545, 0.763636, 0.436364, 0.368398, 0.231602, 0.695671, 0.349784, 0.550216, 0.413420, 0.350216, 0.536364, 0.400000, 0.463636, 0.363636, 0.363636, 0.518182, 0.458874, 0.413853, 0.463636, 0.409091, 0.418182, 0.467965, 0.422944, 0.400000, 0.467965, 0.418182, 0.409091, 0.409091, 0.432035, 0.463636, 0.400000, 0.472727, 0.400000, 0.400000, 0.458874, 0.427273, 0.422944, 0.463636, 0.349784, 0.209091, 0.558874, 0.558874, 0.418182, 0.400000, 0.468398, 0.109091, 0.727273, 0.413420, 0.404762, 0.458874, 0.404762, 0.400000, 0.395238, 0.450216, 0.454545, 0.436364, 0.395238, 0.395671, 0.436364, 0.458874, 0.318182, 0.461472, 0.461472, 0.386147, 0.386147, 0.386147, 0.590909, 0.359307, 0.440693, 0.440693, 0.413420, 0.427273, 0.241126, 0.618182, 0.381818, 0.390909, 0.518182, 0.240693, 0.627273, 0.304762, 0.386147, 0.541126, 0.427273, 0.415584, 0.415584, 0.359307, 0.531602, 0.354545, 0.438528, 0.438528, 0.418182, 0.449784, 0.390909] [0.413853, 0.336364, 0.531602, 0.409091, 0.345455, 0.461472, 0.461472, 0.463636, 0.345455, 0.513420, 0.400000, 0.459307, 0.286147, 0.447619, 0.447619, 0.531602, 0.400000, 0.450216, 0.418182, 0.340693, 0.495671, 0.463636, 0.404329, 
0.454545, 0.400000, 0.354545, 0.470563, 0.470563, 0.409091, 0.463636, 0.372727, 0.372727, 0.400000, 0.463636, 0.413420, 0.404762, 0.463636, 0.400000, 0.463636, 0.413420, 0.463636, 0.404762, 0.467965, 0.409091, 0.468398, 0.413420, 0.418182, 0.463636, 0.404762, 0.463636, 0.427273, 0.409091, 0.467965, 0.404762, 0.463636, 0.409091, 0.463636, 0.404329, 0.463636, 0.418182, 0.468398, 0.413420, 0.468398, 0.404329, 0.463636, 0.409091, 0.463636, 0.422944, 0.354545, 0.490909, 0.304329, 0.495671, 0.456710, 0.456710, 0.481818, 0.404762, 0.427273, 0.113420, 0.786580, 0.404329, 0.477489, 0.404329, 0.413853, 0.463636, 0.404329, 0.463636, 0.404762, 0.463636, 0.409091, 0.463636, 0.404329, 0.463636, 0.404762, 0.463636, 0.404329, 0.468398] [0.409091, 0.422511, 0.459307, 0.409091, 0.458874, 0.409091, 0.459307, 0.404329, 0.463636, 0.404762, 0.458874, 0.404762, 0.467965, 0.404762, 0.463636, 0.409091, 0.467965, 0.409091, 0.472727, 0.409091, 0.232035, 0.649784, 0.413853, 0.463636, 0.409091, 0.467965, 0.409091, 0.463636, 0.409091, 0.468398, 0.409091, 0.467965, 0.404762, 0.472727, 0.404329, 0.481818, 0.404762, 0.295238, 0.536364, 0.477489, 0.236364, 0.649784, 0.422944, 0.472727, 0.349784, 0.490909, 0.463636, 0.413853, 0.477056, 0.404762, 0.404329, 0.404762, 0.463636, 0.467965, 0.400000, 0.463636, 0.404762, 0.463636, 0.404329, 0.459307, 0.404329, 0.459307, 0.404329, 0.459307, 0.340693, 0.536364, 0.400000, 0.459307, 0.400000, 0.400000, 0.467965, 0.390909, 0.450216, 0.386147, 0.441126, 0.386147, 0.441126, 0.390909, 0.449784, 0.395671, 0.445455, 0.395238, 0.450216, 0.395238, 0.450216, 0.400000, 0.454545, 0.400000, 0.449784, 0.400000, 0.454545, 0.404762, 0.467965, 0.404762, 0.413420, 0.463636] [0.400000, 0.395671, 0.454545, 0.400000, 0.454545, 0.395238, 0.450216, 0.400000, 0.458874, 0.400000, 0.454545, 0.400000, 0.454545, 0.395671, 0.454545, 0.400000, 0.454545, 0.395238, 0.459307, 0.395238, 0.459307, 0.404329, 0.454545, 0.345455, 0.513853, 0.395238, 0.459307, 0.400000, 0.454545, 0.413420, 
0.404762, 0.458874, 0.345455, 0.467965, 0.467965, 0.432035, 0.231602, 0.590909, 0.395671, 0.440693, 0.400000, 0.272727, 0.418182, 0.559307, 0.386147, 0.272727, 0.527273, 0.381818, 0.438528, 0.438528, 0.404329, 0.386580, 0.436364, 0.377056, 0.432035, 0.386147, 0.432035, 0.377056, 0.436364, 0.381818, 0.441126, 0.381818, 0.436364, 0.381818, 0.436364, 0.377056, 0.436364, 0.377489, 0.431602, 0.381818, 0.432035, 0.377056, 0.432035, 0.381818, 0.427273, 0.381818, 0.427273, 0.377056, 0.432035, 0.377056, 0.432035, 0.377056, 0.432035, 0.395238, 0.322944, 0.286147, 0.254545, 0.732035, 0.309091, 0.563636, 0.386147, 0.190909, 0.432035, 0.500000, 0.500000, 0.404329] [0.436364, 0.404762, 0.440693, 0.413853, 0.390909, 0.200000, 0.451515, 0.451515, 0.451515, 0.451515, 0.451515, 0.436364, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.445455, 0.386147, 0.441126, 0.390909, 0.440693, 0.390909, 0.441126, 0.395238, 0.441126, 0.386147, 0.445455, 0.386580, 0.418182, 0.431602, 0.418182, 0.395671, 0.458874, 0.395671, 0.449784, 0.400000, 0.450216, 0.400000, 0.404329, 0.472727, 0.395671, 0.458874, 0.395671, 0.404329, 0.454545, 0.390909, 0.445455, 0.395671, 0.445455, 0.395238, 0.454545, 0.390909, 0.450216, 0.390909, 0.449784, 0.395671, 0.449784, 0.390909, 0.395671, 0.463636, 0.400000, 0.427273, 0.449784, 0.427273, 0.409091, 0.277489, 0.545455, 0.445455, 0.390909, 0.177056, 0.672727, 0.336364, 0.481818, 0.395671, 0.454545, 0.400000, 0.454545, 0.454545, 0.395238, 0.390909, 0.450216, 0.395238, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216] [0.390909, 0.449784, 0.390909, 0.341126, 0.518182, 0.390909, 0.449784, 0.390909, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216, 0.395238, 0.450216, 0.395238, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216, 0.390909, 0.449784, 0.400000, 0.450216, 0.381818, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.449784, 0.390909, 0.445455, 0.390909, 0.445455, 
0.390909, 0.450216, 0.390909, 0.449784, 0.390909, 0.445455, 0.395671, 0.440693, 0.386580, 0.440693, 0.381818, 0.445455, 0.386580, 0.440693, 0.386580, 0.440693, 0.390909, 0.441126, 0.381818, 0.440693, 0.381818, 0.441126, 0.386147, 0.395671, 0.413420, 0.450216, 0.386147, 0.441126, 0.386147, 0.445455, 0.390909, 0.441126, 0.409091, 0.386147, 0.441126, 0.386147, 0.445455, 0.390909, 0.441126, 0.381818, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.218182, 0.497835, 0.497835, 0.381818, 0.436364, 0.381818, 0.431602, 0.381818, 0.441126, 0.381818, 0.427273, 0.386147, 0.163636, ] [1.000000, 0.432035, 0.427273, 0.377056, 0.427273, 0.368398, 0.418182, 0.386147, 0.368398, 0.418182, 0.363636, 0.418182, 0.363636, 0.418182, 0.363636, 0.418182, 0.363636, 0.413420, 0.363636, 0.418182, 0.363636, 0.413853, 0.345455, 0.413420, 0.363636, 0.409091, 0.359307, 0.409091, 0.354545, 0.404329, 0.354545, 0.409091, 0.354545, 0.409091, 0.336364, 0.377489, 0.327273, 0.377056, 0.327273, 0.286580, 0.431602, 0.332035, 0.381818, 0.331602, 0.372727, 0.313853, 0.358874, 0.309091, 0.350216, 0.309091, 0.349784, 0.309091, 0.350216, 0.304329, 0.359307, 0.304329, 0.354545, 0.304762, 0.349784, 0.313853, 0.354545, 0.329437, 0.329437, 0.313853, 0.358874, 0.313853, 0.358874, 0.313853, 0.354545, 0.309091, 0.358874, 0.309091, 0.386580, 0.331602, 0.377489, 0.331602, 0.381818, 0.332035, 0.377056, 0.332035, 0.381818, 0.331602, 0.381818, 0.332035, 0.381818, 0.331602, 0.381818, 0.332035, 0.377056, 0.332035, 0.281818, 0.436364, 0.331602, 0.381818, 0.332035, 0.377056] [0.332035, 0.377056, 0.345455, 0.381818, 0.332035, 0.381818, 0.331602, 0.377489, 0.331602, 0.377489, 0.331602, 0.377489, 0.327273, 0.377056, 0.332035, 0.377056, 0.332035, 0.377056, 0.332035, 0.377056, 0.336364, 0.327273, 0.390909, 0.332035, 0.377056, 0.327273, 0.377489, 0.327273, 0.377056, 0.327273, 0.381818, 0.332035, 0.381818, 0.331602, 0.381818, 0.336364, 0.359307, 0.309091, 0.218182, 0.436364, 0.263636, 0.367965, 0.367965, 0.354545, 0.318182, 
0.304329, 0.359307, 0.318182, 0.304329, 0.350216, 0.304329, 0.354545, 0.318182, 0.304762, 0.367965, 0.304762, 0.345455, 0.309091, 0.313420, 0.354545, 0.309091, 0.350216, 0.309091, 0.354545, 0.309091, 0.349784, 0.313853, 0.354545, 0.267965, , , , , , , , , , , , 0.335931, 0.341126, 0.386147, 0.336364, 0.381818, 0.336364, 0.381818, 0.332035, 0.390909, 0.340693, 0.386580, 0.336364, 0.386147, 0.332035, 0.390909, 0.336364] [0.386147, 0.336364, 0.386580, 0.336364, 0.381818, 0.331602, 0.386580, 0.331602, 0.381818, 0.336364, 0.386580, 0.336364, 0.286147, 0.436364, 0.336364, 0.386580, 0.336364, 0.386147, 0.386580, 0.340693, 0.341126, 0.390909, 0.354545, 0.340693, 0.390909, 0.350216, 0.381818, 0.336364, 0.381818, 0.313420, 0.395671, 0.331602, 0.386580, 0.140693, 0.563636, 0.400000, 0.336364, , 0.609091, 0.395671, 0.345455, 0.395238, 0.350216, 0.400000, 0.336364, 0.386147, 0.336364, 0.336364, 0.381818, 0.332035, 0.381818, 0.349784, 0.336364, 0.390909, 0.329870, 0.385281, 0.329004, 0.341991, 0.367965, 0.380952, 0.346320, 0.324675, 0.354978, 0.307359, 0.359307, 0.320346, 0.311688, 0.354978, 0.311688, 0.341991, 0.303030, 0.350649, 0.311688, 0.346320, 0.316017, 0.354978, 0.311688, 0.354978, 0.303030, 0.354978, 0.311688, 0.367965, 0.346320, 0.316017, 0.103896, 0.584416, 0.467532, 0.329004, 0.389610, 0.238095, 0.378788, 0.378788, 0.298701, 0.419913, 0.419913, 0.333333]]

forward_propagation output predict data------------------------------------ [0.404532, 0.476134, 0.328844, 0.426934, 0.270703, 0.458620, 0.341076, 0.334768, 0.468806, 0.304100, 0.976856, 0.315633, 0.420080, 0.062515, 0.243529, 0.221136, 0.404333, 0.379710, 0.440636, 0.356465, 0.120904, 0.036068, 0.287320, 0.374484, 0.300462, 0.397743, 0.207402, 0.226563, 0.257873, 0.267984, 0.551578, 0.347965, 0.331977, 0.301363, 0.041548, 0.290635, 0.394845, 0.437435, 0.613412, 0.169130, 0.484039, 0.268901, 0.247796, 0.325273, 0.417024, 0.366171, 0.216229, 0.467116, 0.398497, 0.199903, 0.391816, 0.418512, 0.197550, 0.502423, 0.497733, 0.267942, 0.299187, 0.030697, 0.022763, 0.084156, 0.413798, 0.255841, 0.337230, 0.331109, 0.200119, 0.338898, 0.239567, 0.228124, 0.326233, 0.135143, 0.378555, 0.342498, 0.543321, 0.506673, 0.310593, 0.307301, 0.291471, 0.306668, 0.158296, 0.426466, 0.378716, 0.427893, 0.497239, 0.209331, 0.458289, 0.330763, 0.121166, 0.311828, 0.263204, 0.174951, 0.280308, 0.349826, 0.421087, 0.453572, 0.998139, 0.303182, 0.380307, 0.249877, 0.347014, 0.343653, 0.418789, 0.364514, 0.335602, 0.377660, 0.484541, 0.346718, 0.122578, 0.509846, 0.297208, 0.381712, 0.374355, 0.422839, 0.069437, 0.174575, 0.378778, 0.348939, 0.536228, 0.225626, 0.254839, 0.407721, 0.317331, 0.186351, 0.270692, 0.206890, 0.226693, 0.370466, 0.377802, 0.447872, 0.448710, 0.987046, 0.461847, 0.453755, 0.316290, 0.409107, 0.354594, 0.319344, 0.329214, 0.058827, 0.341980, 0.350576, 0.302538, 0.213071, 0.184631, 0.333771, 0.314654, 0.317235, 0.201519, 0.231514, 0.097600, 0.303638, 0.133759, 0.325704, 0.324617, 0.299926, 0.395951, 0.375271, 0.415061, 0.371168, 0.322983, 0.275537, 0.287059, 0.269967, 0.374939, 0.985260, 0.405212, 0.203774, 0.523954, 0.487154, 0.377398, 0.342924, 0.252147, 0.334527, 0.007345, 0.344799, 0.329339, 0.336787, 0.304066, 0.137525, 0.996247, 0.339163, 0.344513, 0.262848, 0.261368, 0.257797, 0.492305, 0.242995, 0.412278, 0.093704, 0.315248, 0.385679, 0.359690, 0.257847] 
predict MAPE loss: 0.0332913

predict inputdata------------------------------------ [[ , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , 0.235065, 0.235065, 0.390909, 0.390909, 0.441126, 0.395238, 0.441126, 0.386147, 0.450216, 0.386147, 0.409091, 0.422944, 0.418182, 0.436364, 0.390909, 0.336364, 0.458874, 0.463636, 0.381818, 0.436364, 0.381818, 0.390909, 0.436364, 0.409091, 0.381818, 0.441126, 0.431602, 0.352381, 0.352381, 0.427273, 0.431602, 0.404762, 0.327273, 0.540693, 0.472727, 0.409091, 0.472727, 0.409091, 0.468398, 0.413420, 0.463636, 0.409091, 0.468398, 0.409091, 0.463636, 0.404329, 0.459307, 0.404329, 0.459307, 0.404329, 0.422944, 0.467965, 0.409091, 0.477489, 0.409091, 0.418182, 0.463636, 0.413420, 0.463636, 0.413853, 0.404329, 0.404762, 0.472727] [0.463636, 0.290909, 0.422511, 0.254545, 0.763636, 0.436364, 0.368398, 0.231602, 0.695671, 0.349784, 0.550216, 0.413420, 0.350216, 0.536364, 0.400000, 0.463636, 0.363636, 0.363636, 0.518182, 0.458874, 0.413853, 0.463636, 0.409091, 0.418182, 0.467965, 0.422944, 0.400000, 0.467965, 0.418182, 0.409091, 0.409091, 0.432035, 0.463636, 0.400000, 0.472727, 0.400000, 0.400000, 0.458874, 0.427273, 0.422944, 0.463636, 0.349784, 0.209091, 0.558874, 0.558874, 0.418182, 0.400000, 0.468398, 0.109091, 0.727273, 0.413420, 0.404762, 0.458874, 0.404762, 0.400000, 0.395238, 0.450216, 0.454545, 0.436364, 0.395238, 0.395671, 0.436364, 0.458874, 0.318182, 0.461472, 0.461472, 0.386147, 0.386147, 0.386147, 0.590909, 0.359307, 0.440693, 0.440693, 0.413420, 0.427273, 0.241126, 0.618182, 0.381818, 0.390909, 0.518182, 0.240693, 0.627273, 0.304762, 0.386147, 0.541126, 0.427273, 0.415584, 0.415584, 0.359307, 0.531602, 0.354545, 0.438528, 0.438528, 0.418182, 0.449784, 0.390909] [0.413853, 0.336364, 0.531602, 0.409091, 0.345455, 0.461472, 0.461472, 0.463636, 0.345455, 0.513420, 0.400000, 0.459307, 0.286147, 0.447619, 0.447619, 0.531602, 0.400000, 0.450216, 0.418182, 0.340693, 0.495671, 0.463636, 0.404329, 0.454545, 0.400000, 0.354545, 0.470563, 0.470563, 
0.409091, 0.463636, 0.372727, 0.372727, 0.400000, 0.463636, 0.413420, 0.404762, 0.463636, 0.400000, 0.463636, 0.413420, 0.463636, 0.404762, 0.467965, 0.409091, 0.468398, 0.413420, 0.418182, 0.463636, 0.404762, 0.463636, 0.427273, 0.409091, 0.467965, 0.404762, 0.463636, 0.409091, 0.463636, 0.404329, 0.463636, 0.418182, 0.468398, 0.413420, 0.468398, 0.404329, 0.463636, 0.409091, 0.463636, 0.422944, 0.354545, 0.490909, 0.304329, 0.495671, 0.456710, 0.456710, 0.481818, 0.404762, 0.427273, 0.113420, 0.786580, 0.404329, 0.477489, 0.404329, 0.413853, 0.463636, 0.404329, 0.463636, 0.404762, 0.463636, 0.409091, 0.463636, 0.404329, 0.463636, 0.404762, 0.463636, 0.404329, 0.468398] [0.409091, 0.422511, 0.459307, 0.409091, 0.458874, 0.409091, 0.459307, 0.404329, 0.463636, 0.404762, 0.458874, 0.404762, 0.467965, 0.404762, 0.463636, 0.409091, 0.467965, 0.409091, 0.472727, 0.409091, 0.232035, 0.649784, 0.413853, 0.463636, 0.409091, 0.467965, 0.409091, 0.463636, 0.409091, 0.468398, 0.409091, 0.467965, 0.404762, 0.472727, 0.404329, 0.481818, 0.404762, 0.295238, 0.536364, 0.477489, 0.236364, 0.649784, 0.422944, 0.472727, 0.349784, 0.490909, 0.463636, 0.413853, 0.477056, 0.404762, 0.404329, 0.404762, 0.463636, 0.467965, 0.400000, 0.463636, 0.404762, 0.463636, 0.404329, 0.459307, 0.404329, 0.459307, 0.404329, 0.459307, 0.340693, 0.536364, 0.400000, 0.459307, 0.400000, 0.400000, 0.467965, 0.390909, 0.450216, 0.386147, 0.441126, 0.386147, 0.441126, 0.390909, 0.449784, 0.395671, 0.445455, 0.395238, 0.450216, 0.395238, 0.450216, 0.400000, 0.454545, 0.400000, 0.449784, 0.400000, 0.454545, 0.404762, 0.467965, 0.404762, 0.413420, 0.463636] [0.400000, 0.395671, 0.454545, 0.400000, 0.454545, 0.395238, 0.450216, 0.400000, 0.458874, 0.400000, 0.454545, 0.400000, 0.454545, 0.395671, 0.454545, 0.400000, 0.454545, 0.395238, 0.459307, 0.395238, 0.459307, 0.404329, 0.454545, 0.345455, 0.513853, 0.395238, 0.459307, 0.400000, 0.454545, 0.413420, 0.404762, 0.458874, 0.345455, 0.467965, 0.467965, 
0.432035, 0.231602, 0.590909, 0.395671, 0.440693, 0.400000, 0.272727, 0.418182, 0.559307, 0.386147, 0.272727, 0.527273, 0.381818, 0.438528, 0.438528, 0.404329, 0.386580, 0.436364, 0.377056, 0.432035, 0.386147, 0.432035, 0.377056, 0.436364, 0.381818, 0.441126, 0.381818, 0.436364, 0.381818, 0.436364, 0.377056, 0.436364, 0.377489, 0.431602, 0.381818, 0.432035, 0.377056, 0.432035, 0.381818, 0.427273, 0.381818, 0.427273, 0.377056, 0.432035, 0.377056, 0.432035, 0.377056, 0.432035, 0.395238, 0.322944, 0.286147, 0.254545, 0.732035, 0.309091, 0.563636, 0.386147, 0.190909, 0.432035, 0.500000, 0.500000, 0.404329] [0.436364, 0.404762, 0.440693, 0.413853, 0.390909, 0.200000, 0.451515, 0.451515, 0.451515, 0.451515, 0.451515, 0.436364, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.445455, 0.386147, 0.441126, 0.390909, 0.440693, 0.390909, 0.441126, 0.395238, 0.441126, 0.386147, 0.445455, 0.386580, 0.418182, 0.431602, 0.418182, 0.395671, 0.458874, 0.395671, 0.449784, 0.400000, 0.450216, 0.400000, 0.404329, 0.472727, 0.395671, 0.458874, 0.395671, 0.404329, 0.454545, 0.390909, 0.445455, 0.395671, 0.445455, 0.395238, 0.454545, 0.390909, 0.450216, 0.390909, 0.449784, 0.395671, 0.449784, 0.390909, 0.395671, 0.463636, 0.400000, 0.427273, 0.449784, 0.427273, 0.409091, 0.277489, 0.545455, 0.445455, 0.390909, 0.177056, 0.672727, 0.336364, 0.481818, 0.395671, 0.454545, 0.400000, 0.454545, 0.454545, 0.395238, 0.390909, 0.450216, 0.395238, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216] [0.390909, 0.449784, 0.390909, 0.341126, 0.518182, 0.390909, 0.449784, 0.390909, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216, 0.395238, 0.450216, 0.395238, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216, 0.390909, 0.449784, 0.400000, 0.450216, 0.381818, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.449784, 0.390909, 0.445455, 0.390909, 0.445455, 0.390909, 0.450216, 0.390909, 0.449784, 0.390909, 
0.445455, 0.395671, 0.440693, 0.386580, 0.440693, 0.381818, 0.445455, 0.386580, 0.440693, 0.386580, 0.440693, 0.390909, 0.441126, 0.381818, 0.440693, 0.381818, 0.441126, 0.386147, 0.395671, 0.413420, 0.450216, 0.386147, 0.441126, 0.386147, 0.445455, 0.390909, 0.441126, 0.409091, 0.386147, 0.441126, 0.386147, 0.445455, 0.390909, 0.441126, 0.381818, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.218182, 0.497835, 0.497835, 0.381818, 0.436364, 0.381818, 0.431602, 0.381818, 0.441126, 0.381818, 0.427273, 0.386147, 0.163636, ] [1.000000, 0.432035, 0.427273, 0.377056, 0.427273, 0.368398, 0.418182, 0.386147, 0.368398, 0.418182, 0.363636, 0.418182, 0.363636, 0.418182, 0.363636, 0.418182, 0.363636, 0.413420, 0.363636, 0.418182, 0.363636, 0.413853, 0.345455, 0.413420, 0.363636, 0.409091, 0.359307, 0.409091, 0.354545, 0.404329, 0.354545, 0.409091, 0.354545, 0.409091, 0.336364, 0.377489, 0.327273, 0.377056, 0.327273, 0.286580, 0.431602, 0.332035, 0.381818, 0.331602, 0.372727, 0.313853, 0.358874, 0.309091, 0.350216, 0.309091, 0.349784, 0.309091, 0.350216, 0.304329, 0.359307, 0.304329, 0.354545, 0.304762, 0.349784, 0.313853, 0.354545, 0.329437, 0.329437, 0.313853, 0.358874, 0.313853, 0.358874, 0.313853, 0.354545, 0.309091, 0.358874, 0.309091, 0.386580, 0.331602, 0.377489, 0.331602, 0.381818, 0.332035, 0.377056, 0.332035, 0.381818, 0.331602, 0.381818, 0.332035, 0.381818, 0.331602, 0.381818, 0.332035, 0.377056, 0.332035, 0.281818, 0.436364, 0.331602, 0.381818, 0.332035, 0.377056] [0.332035, 0.377056, 0.345455, 0.381818, 0.332035, 0.381818, 0.331602, 0.377489, 0.331602, 0.377489, 0.331602, 0.377489, 0.327273, 0.377056, 0.332035, 0.377056, 0.332035, 0.377056, 0.332035, 0.377056, 0.336364, 0.327273, 0.390909, 0.332035, 0.377056, 0.327273, 0.377489, 0.327273, 0.377056, 0.327273, 0.381818, 0.332035, 0.381818, 0.331602, 0.381818, 0.336364, 0.359307, 0.309091, 0.218182, 0.436364, 0.263636, 0.367965, 0.367965, 0.354545, 0.318182, 0.304329, 0.359307, 0.318182, 0.304329, 0.350216, 
0.304329, 0.354545, 0.318182, 0.304762, 0.367965, 0.304762, 0.345455, 0.309091, 0.313420, 0.354545, 0.309091, 0.350216, 0.309091, 0.354545, 0.309091, 0.349784, 0.313853, 0.354545, 0.267965, , , , , , , , , , , , 0.335931, 0.341126, 0.386147, 0.336364, 0.381818, 0.336364, 0.381818, 0.332035, 0.390909, 0.340693, 0.386580, 0.336364, 0.386147, 0.332035, 0.390909, 0.336364] [0.386147, 0.336364, 0.386580, 0.336364, 0.381818, 0.331602, 0.386580, 0.331602, 0.381818, 0.336364, 0.386580, 0.336364, 0.286147, 0.436364, 0.336364, 0.386580, 0.336364, 0.386147, 0.386580, 0.340693, 0.341126, 0.390909, 0.354545, 0.340693, 0.390909, 0.350216, 0.381818, 0.336364, 0.381818, 0.313420, 0.395671, 0.331602, 0.386580, 0.140693, 0.563636, 0.400000, 0.336364, , 0.609091, 0.395671, 0.345455, 0.395238, 0.350216, 0.400000, 0.336364, 0.386147, 0.336364, 0.336364, 0.381818, 0.332035, 0.381818, 0.349784, 0.336364, 0.390909, 0.329870, 0.385281, 0.329004, 0.341991, 0.367965, 0.380952, 0.346320, 0.324675, 0.354978, 0.307359, 0.359307, 0.320346, 0.311688, 0.354978, 0.311688, 0.341991, 0.303030, 0.350649, 0.311688, 0.346320, 0.316017, 0.354978, 0.311688, 0.354978, 0.303030, 0.354978, 0.311688, 0.367965, 0.346320, 0.316017, 0.103896, 0.584416, 0.467532, 0.329004, 0.389610, 0.238095, 0.378788, 0.378788, 0.298701, 0.419913, 0.419913, 0.333333]]

predict output predict data------------------------------------ [0.374448, 0.564394, 0.354690, 0.465649, 0.266165, 0.465072, 0.316884, 0.369836, 0.471120, 0.262616, 0.969885, 0.379447, 0.466348, 0.098868, 0.260437, 0.215778, 0.422643, 0.376386, 0.429980, 0.348787, 0.176748, 0.051450, 0.297591, 0.359938, 0.301524, 0.443577, 0.257440, 0.227588, 0.261359, 0.309156, 0.539591, 0.351309, 0.358653, 0.332885, 0.035002, 0.394732, 0.445727, 0.424722, 0.619815, 0.166040, 0.447071, 0.220799, 0.291700, 0.325852, 0.494173, 0.414733, 0.249242, 0.529702, 0.373063, 0.210945, 0.440381, 0.458675, 0.218825, 0.495972, 0.517043, 0.313683, 0.337259, 0.039814, 0.031061, 0.082349, 0.428211, 0.316373, 0.413368, 0.398967, 0.219137, 0.344268, 0.279681, 0.234534, 0.375327, 0.105598, 0.400401, 0.366577, 0.520584, 0.511029, 0.310017, 0.345755, 0.328848, 0.406407, 0.140360, 0.484105, 0.412681, 0.377267, 0.541156, 0.214963, 0.470064, 0.391646, 0.138953, 0.253455, 0.384632, 0.154159, 0.234578, 0.361021, 0.466081, 0.453252, 0.997117, 0.273981, 0.344303, 0.287551, 0.369627, 0.371841, 0.405163, 0.434129, 0.379777, 0.382589, 0.607213, 0.438022, 0.132803, 0.511256, 0.337725, 0.419296, 0.409784, 0.404004, 0.068436, 0.154254, 0.323000, 0.381302, 0.557047, 0.276301, 0.246935, 0.418109, 0.318567, 0.245762, 0.350656, 0.275067, 0.201043, 0.449263, 0.362461, 0.453186, 0.415039, 0.986109, 0.460475, 0.398715, 0.324316, 0.374510, 0.354861, 0.375441, 0.344285, 0.060621, 0.362904, 0.331070, 0.312567, 0.217825, 0.187484, 0.343767, 0.331342, 0.270052, 0.187798, 0.300415, 0.098467, 0.358481, 0.146371, 0.278904, 0.400584, 0.367681, 0.417191, 0.345485, 0.394248, 0.358052, 0.344753, 0.253110, 0.251398, 0.323197, 0.417996, 0.982758, 0.386676, 0.183951, 0.458941, 0.422835, 0.355686, 0.292942, 0.268050, 0.394068, 0.008641, 0.391218, 0.325866, 0.348157, 0.335110, 0.126606, 0.995094, 0.268525, 0.336710, 0.241941, 0.234148, 0.221454, 0.550659, 0.290820, 0.422846, 0.110844, 0.331951, 0.465628, 0.324932, 0.326787]

single_predict inputdata------------------------------------ [[ , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , 0.235065, 0.235065, 0.390909, 0.390909, 0.441126, 0.395238, 0.441126, 0.386147, 0.450216, 0.386147, 0.409091, 0.422944, 0.418182, 0.436364, 0.390909, 0.336364, 0.458874, 0.463636, 0.381818, 0.436364, 0.381818, 0.390909, 0.436364, 0.409091, 0.381818, 0.441126, 0.431602, 0.352381, 0.352381, 0.427273, 0.431602, 0.404762, 0.327273, 0.540693, 0.472727, 0.409091, 0.472727, 0.409091, 0.468398, 0.413420, 0.463636, 0.409091, 0.468398, 0.409091, 0.463636, 0.404329, 0.459307, 0.404329, 0.459307, 0.404329, 0.422944, 0.467965, 0.409091, 0.477489, 0.409091, 0.418182, 0.463636, 0.413420, 0.463636, 0.413853, 0.404329, 0.404762, 0.472727] [0.463636, 0.290909, 0.422511, 0.254545, 0.763636, 0.436364, 0.368398, 0.231602, 0.695671, 0.349784, 0.550216, 0.413420, 0.350216, 0.536364, 0.400000, 0.463636, 0.363636, 0.363636, 0.518182, 0.458874, 0.413853, 0.463636, 0.409091, 0.418182, 0.467965, 0.422944, 0.400000, 0.467965, 0.418182, 0.409091, 0.409091, 0.432035, 0.463636, 0.400000, 0.472727, 0.400000, 0.400000, 0.458874, 0.427273, 0.422944, 0.463636, 0.349784, 0.209091, 0.558874, 0.558874, 0.418182, 0.400000, 0.468398, 0.109091, 0.727273, 0.413420, 0.404762, 0.458874, 0.404762, 0.400000, 0.395238, 0.450216, 0.454545, 0.436364, 0.395238, 0.395671, 0.436364, 0.458874, 0.318182, 0.461472, 0.461472, 0.386147, 0.386147, 0.386147, 0.590909, 0.359307, 0.440693, 0.440693, 0.413420, 0.427273, 0.241126, 0.618182, 0.381818, 0.390909, 0.518182, 0.240693, 0.627273, 0.304762, 0.386147, 0.541126, 0.427273, 0.415584, 0.415584, 0.359307, 0.531602, 0.354545, 0.438528, 0.438528, 0.418182, 0.449784, 0.390909] [0.413853, 0.336364, 0.531602, 0.409091, 0.345455, 0.461472, 0.461472, 0.463636, 0.345455, 0.513420, 0.400000, 0.459307, 0.286147, 0.447619, 0.447619, 0.531602, 0.400000, 0.450216, 0.418182, 0.340693, 0.495671, 0.463636, 0.404329, 0.454545, 0.400000, 0.354545, 0.470563, 
0.470563, 0.409091, 0.463636, 0.372727, 0.372727, 0.400000, 0.463636, 0.413420, 0.404762, 0.463636, 0.400000, 0.463636, 0.413420, 0.463636, 0.404762, 0.467965, 0.409091, 0.468398, 0.413420, 0.418182, 0.463636, 0.404762, 0.463636, 0.427273, 0.409091, 0.467965, 0.404762, 0.463636, 0.409091, 0.463636, 0.404329, 0.463636, 0.418182, 0.468398, 0.413420, 0.468398, 0.404329, 0.463636, 0.409091, 0.463636, 0.422944, 0.354545, 0.490909, 0.304329, 0.495671, 0.456710, 0.456710, 0.481818, 0.404762, 0.427273, 0.113420, 0.786580, 0.404329, 0.477489, 0.404329, 0.413853, 0.463636, 0.404329, 0.463636, 0.404762, 0.463636, 0.409091, 0.463636, 0.404329, 0.463636, 0.404762, 0.463636, 0.404329, 0.468398] [0.409091, 0.422511, 0.459307, 0.409091, 0.458874, 0.409091, 0.459307, 0.404329, 0.463636, 0.404762, 0.458874, 0.404762, 0.467965, 0.404762, 0.463636, 0.409091, 0.467965, 0.409091, 0.472727, 0.409091, 0.232035, 0.649784, 0.413853, 0.463636, 0.409091, 0.467965, 0.409091, 0.463636, 0.409091, 0.468398, 0.409091, 0.467965, 0.404762, 0.472727, 0.404329, 0.481818, 0.404762, 0.295238, 0.536364, 0.477489, 0.236364, 0.649784, 0.422944, 0.472727, 0.349784, 0.490909, 0.463636, 0.413853, 0.477056, 0.404762, 0.404329, 0.404762, 0.463636, 0.467965, 0.400000, 0.463636, 0.404762, 0.463636, 0.404329, 0.459307, 0.404329, 0.459307, 0.404329, 0.459307, 0.340693, 0.536364, 0.400000, 0.459307, 0.400000, 0.400000, 0.467965, 0.390909, 0.450216, 0.386147, 0.441126, 0.386147, 0.441126, 0.390909, 0.449784, 0.395671, 0.445455, 0.395238, 0.450216, 0.395238, 0.450216, 0.400000, 0.454545, 0.400000, 0.449784, 0.400000, 0.454545, 0.404762, 0.467965, 0.404762, 0.413420, 0.463636] [0.400000, 0.395671, 0.454545, 0.400000, 0.454545, 0.395238, 0.450216, 0.400000, 0.458874, 0.400000, 0.454545, 0.400000, 0.454545, 0.395671, 0.454545, 0.400000, 0.454545, 0.395238, 0.459307, 0.395238, 0.459307, 0.404329, 0.454545, 0.345455, 0.513853, 0.395238, 0.459307, 0.400000, 0.454545, 0.413420, 0.404762, 0.458874, 0.345455, 0.467965, 
0.467965, 0.432035, 0.231602, 0.590909, 0.395671, 0.440693, 0.400000, 0.272727, 0.418182, 0.559307, 0.386147, 0.272727, 0.527273, 0.381818, 0.438528, 0.438528, 0.404329, 0.386580, 0.436364, 0.377056, 0.432035, 0.386147, 0.432035, 0.377056, 0.436364, 0.381818, 0.441126, 0.381818, 0.436364, 0.381818, 0.436364, 0.377056, 0.436364, 0.377489, 0.431602, 0.381818, 0.432035, 0.377056, 0.432035, 0.381818, 0.427273, 0.381818, 0.427273, 0.377056, 0.432035, 0.377056, 0.432035, 0.377056, 0.432035, 0.395238, 0.322944, 0.286147, 0.254545, 0.732035, 0.309091, 0.563636, 0.386147, 0.190909, 0.432035, 0.500000, 0.500000, 0.404329] [0.436364, 0.404762, 0.440693, 0.413853, 0.390909, 0.200000, 0.451515, 0.451515, 0.451515, 0.451515, 0.451515, 0.436364, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.445455, 0.386147, 0.441126, 0.390909, 0.440693, 0.390909, 0.441126, 0.395238, 0.441126, 0.386147, 0.445455, 0.386580, 0.418182, 0.431602, 0.418182, 0.395671, 0.458874, 0.395671, 0.449784, 0.400000, 0.450216, 0.400000, 0.404329, 0.472727, 0.395671, 0.458874, 0.395671, 0.404329, 0.454545, 0.390909, 0.445455, 0.395671, 0.445455, 0.395238, 0.454545, 0.390909, 0.450216, 0.390909, 0.449784, 0.395671, 0.449784, 0.390909, 0.395671, 0.463636, 0.400000, 0.427273, 0.449784, 0.427273, 0.409091, 0.277489, 0.545455, 0.445455, 0.390909, 0.177056, 0.672727, 0.336364, 0.481818, 0.395671, 0.454545, 0.400000, 0.454545, 0.454545, 0.395238, 0.390909, 0.450216, 0.395238, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216] [0.390909, 0.449784, 0.390909, 0.341126, 0.518182, 0.390909, 0.449784, 0.390909, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216, 0.395238, 0.450216, 0.395238, 0.450216, 0.390909, 0.449784, 0.390909, 0.450216, 0.390909, 0.449784, 0.400000, 0.450216, 0.381818, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.386580, 0.449784, 0.390909, 0.445455, 0.390909, 0.445455, 0.390909, 0.450216, 0.390909, 0.449784, 
0.390909, 0.445455, 0.395671, 0.440693, 0.386580, 0.440693, 0.381818, 0.445455, 0.386580, 0.440693, 0.386580, 0.440693, 0.390909, 0.441126, 0.381818, 0.440693, 0.381818, 0.441126, 0.386147, 0.395671, 0.413420, 0.450216, 0.386147, 0.441126, 0.386147, 0.445455, 0.390909, 0.441126, 0.409091, 0.386147, 0.441126, 0.386147, 0.445455, 0.390909, 0.441126, 0.381818, 0.440693, 0.386580, 0.440693, 0.386580, 0.440693, 0.218182, 0.497835, 0.497835, 0.381818, 0.436364, 0.381818, 0.431602, 0.381818, 0.441126, 0.381818, 0.427273, 0.386147, 0.163636, ] [1.000000, 0.432035, 0.427273, 0.377056, 0.427273, 0.368398, 0.418182, 0.386147, 0.368398, 0.418182, 0.363636, 0.418182, 0.363636, 0.418182, 0.363636, 0.418182, 0.363636, 0.413420, 0.363636, 0.418182, 0.363636, 0.413853, 0.345455, 0.413420, 0.363636, 0.409091, 0.359307, 0.409091, 0.354545, 0.404329, 0.354545, 0.409091, 0.354545, 0.409091, 0.336364, 0.377489, 0.327273, 0.377056, 0.327273, 0.286580, 0.431602, 0.332035, 0.381818, 0.331602, 0.372727, 0.313853, 0.358874, 0.309091, 0.350216, 0.309091, 0.349784, 0.309091, 0.350216, 0.304329, 0.359307, 0.304329, 0.354545, 0.304762, 0.349784, 0.313853, 0.354545, 0.329437, 0.329437, 0.313853, 0.358874, 0.313853, 0.358874, 0.313853, 0.354545, 0.309091, 0.358874, 0.309091, 0.386580, 0.331602, 0.377489, 0.331602, 0.381818, 0.332035, 0.377056, 0.332035, 0.381818, 0.331602, 0.381818, 0.332035, 0.381818, 0.331602, 0.381818, 0.332035, 0.377056, 0.332035, 0.281818, 0.436364, 0.331602, 0.381818, 0.332035, 0.377056] [0.332035, 0.377056, 0.345455, 0.381818, 0.332035, 0.381818, 0.331602, 0.377489, 0.331602, 0.377489, 0.331602, 0.377489, 0.327273, 0.377056, 0.332035, 0.377056, 0.332035, 0.377056, 0.332035, 0.377056, 0.336364, 0.327273, 0.390909, 0.332035, 0.377056, 0.327273, 0.377489, 0.327273, 0.377056, 0.327273, 0.381818, 0.332035, 0.381818, 0.331602, 0.381818, 0.336364, 0.359307, 0.309091, 0.218182, 0.436364, 0.263636, 0.367965, 0.367965, 0.354545, 0.318182, 0.304329, 0.359307, 0.318182, 0.304329, 
0.350216, 0.304329, 0.354545, 0.318182, 0.304762, 0.367965, 0.304762, 0.345455, 0.309091, 0.313420, 0.354545, 0.309091, 0.350216, 0.309091, 0.354545, 0.309091, 0.349784, 0.313853, 0.354545, 0.267965, , , , , , , , , , , , 0.335931, 0.341126, 0.386147, 0.336364, 0.381818, 0.336364, 0.381818, 0.332035, 0.390909, 0.340693, 0.386580, 0.336364, 0.386147, 0.332035, 0.390909, 0.336364] [0.386147, 0.336364, 0.386580, 0.336364, 0.381818, 0.331602, 0.386580, 0.331602, 0.381818, 0.336364, 0.386580, 0.336364, 0.286147, 0.436364, 0.336364, 0.386580, 0.336364, 0.386147, 0.386580, 0.340693, 0.341126, 0.390909, 0.354545, 0.340693, 0.390909, 0.350216, 0.381818, 0.336364, 0.381818, 0.313420, 0.395671, 0.331602, 0.386580, 0.140693, 0.563636, 0.400000, 0.336364, , 0.609091, 0.395671, 0.345455, 0.395238, 0.350216, 0.400000, 0.336364, 0.386147, 0.336364, 0.336364, 0.381818, 0.332035, 0.381818, 0.349784, 0.336364, 0.390909, 0.329870, 0.385281, 0.329004, 0.341991, 0.367965, 0.380952, 0.346320, 0.324675, 0.354978, 0.307359, 0.359307, 0.320346, 0.311688, 0.354978, 0.311688, 0.341991, 0.303030, 0.350649, 0.311688, 0.346320, 0.316017, 0.354978, 0.311688, 0.354978, 0.303030, 0.354978, 0.311688, 0.367965, 0.346320, 0.316017, 0.103896, 0.584416, 0.467532, 0.329004, 0.389610, 0.238095, 0.378788, 0.378788, 0.298701, 0.419913, 0.419913, 0.333333]] single_predict output predict data------------------------------------ [0.715823, 0.786678, 0.664938, 0.224606, 0.330309, 0.594333, 0.702704, 0.726132, 0.223699, 0.176048, 0.892282, 0.798745, 0.338442, 0.614723, 0.269472, 0.089377, 0.383681, 0.881474, 0.719474, 0.764565, 0.710597, 0.649136, 0.492436, 0.578732, 0.410917, 0.819644, 0.580089, 0.236140, 0.303525, 0.732678, 0.594589, 0.172383, 0.356966, 0.481186, 0.119218, 0.890867, 0.524581, 0.714259, 0.851835, 0.081749, 0.799770, 0.193316, 0.860742, 0.109847, 0.839480, 0.825123, 0.739441, 0.780463, 0.451633, 0.472697, 0.763012, 0.841725, 0.643023, 0.484841, 0.821831, 0.390518, 0.577800, 0.183055, 0.085486, 
0.385336, 0.440619, 0.638140, 0.681771, 0.753230, 0.425089, 0.547126, 0.413822, 0.185765, 0.674269, 0.110794, 0.389893, 0.436003, 0.547286, 0.747170, 0.426069, 0.340636, 0.461366, 0.872608, 0.170463, 0.821453, 0.686799, 0.197970, 0.832788, 0.123749, 0.890811, 0.788550, 0.230534, 0.150522, 0.665755, 0.178344, 0.093394, 0.190971, 0.685109, 0.473930, 0.937859, 0.180150, 0.372106, 0.180841, 0.559759, 0.724854, 0.701096, 0.846477, 0.818112, 0.427110, 0.832867, 0.707348, 0.329092, 0.720792, 0.328781, 0.434321, 0.841834, 0.368064, 0.267720, 0.207651, 0.113088, 0.197918, 0.910166, 0.167526, 0.119777, 0.895569, 0.469350, 0.529826, 0.547047, 0.549895, 0.259217, 0.740872, 0.570477, 0.591771, 0.247182, 0.886635, 0.517501, 0.392819, 0.657301, 0.164558, 0.412941, 0.244990, 0.201201, 0.139451, 0.727651, 0.624378, 0.198868, 0.453494, 0.047936, 0.320337, 0.693498, 0.746599, 0.138575, 0.835170, 0.128552, 0.837345, 0.086434, 0.260558, 0.827286, 0.677093, 0.449687, 0.141088, 0.236419, 0.922765, 0.541133, 0.140580, 0.518347, 0.822704, 0.852684, 0.937660, 0.491249, 0.133104, 0.148482, 0.669517, 0.769244, 0.143638, 0.456581, 0.691282, 0.080960, 0.281191, 0.138474, 0.514065, 0.738914, 0.206140, 0.918884, 0.151448, 0.294774, 0.186677, 0.637583, 0.304582, 0.887552, 0.420795, 0.733054, 0.523726, 0.682746, 0.587815, 0.121862, 0.605980]

xinsuinizhuan commented 4 years ago

Hi!

I just added max_pooling: #49

Next steps: add GPU support for convolution, add GPU support for max_pooling, add Attention-LSTM, add optimizers.

Is single_predict still giving different results? (For me the results are slightly different but very close; I am unsure what is causing this.) If this is still an issue for you, I will prioritize it.

I only trained for 1028 epochs; setting the epoch count to 5000 might give better results, but the single_predict output still deviates much more than the others. In the screenshot, the first red box is the real output data, the second is the forward_propagation output, the third is the predict output, and the fourth is the single_predict output:

[image]

I think you should print every intermediate value during the calculation to find where the outputs start to differ, work out what causes it, and then fix it.
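A minimal sketch of the debugging approach suggested above: dump the same intermediate buffer (for example, one layer's output) from both code paths and report the first element where they disagree beyond a tolerance. The helper name `first_divergence` and the use of plain `std::vector<float>` buffers are assumptions for illustration; they are not part of the BlackCat Tensors API.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical helper (not part of BlackCat Tensors): returns the index of the
// first element where |a[i] - b[i]| exceeds `tol`, or std::nullopt if the two
// buffers match within tolerance. Buffers of different length are reported
// as diverging at the end of the shorter one.
std::optional<std::size_t> first_divergence(const std::vector<float>& a,
                                            const std::vector<float>& b,
                                            float tol = 1e-5f) {
    const std::size_t n = a.size() < b.size() ? a.size() : b.size();
    for (std::size_t i = 0; i < n; ++i) {
        if (std::fabs(a[i] - b[i]) > tol) {
            return i;
        }
    }
    if (a.size() != b.size()) {
        return n;
    }
    return std::nullopt;
}
```

Applied layer by layer, starting from the first layer, the first layer whose outputs diverge is the place to look for the bug.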

xinsuinizhuan commented 4 years ago

I think the next steps should be: fix the single_predict output, add optimizers, add GPU support for convolution, add GPU support for max_pooling, and add Attention-LSTM.

josephjaspers commented 4 years ago

I found the issue with single_predict: some of the data wasn't being copied correctly in the recurrent layers. https://github.com/josephjaspers/blackcat_tensors/commit/7d284874ff9e874cb86d12a41b20c7cf81c729ee

xinsuinizhuan commented 4 years ago

I tested it, and it still does not seem to work:

[image]

josephjaspers commented 4 years ago

Did you use copy_training_data_to_single_predict?

Example:

    network.copy_training_data_to_single_predict(batch_index);

xinsuinizhuan commented 4 years ago

In my last test I had not called copy_training_data_to_single_predict, so it did not seem to work. This time I called it. The result is not the same as forward_propagation, but it is very close. Is something still wrong, or is it correct that they are this close rather than identical?

[image]

Another question: what does the parameter of copy_training_data_to_single_predict mean? Is 0 OK?

[image]
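On whether "very close but not identical" can be correct: one possible explanation, not confirmed in this thread, is that the batched path and the single-sample path accumulate the same products in a different order, and floating-point addition is not associative. The sketch below only demonstrates that general effect with plain floats; it does not show what BlackCat Tensors actually does internally.

```cpp
#include <cassert>
#include <vector>

// Sum left to right.
float sum_forward(const std::vector<float>& v) {
    float s = 0.0f;
    for (float x : v) s += x;
    return s;
}

// Sum right to left: mathematically the same sum, but rounding happens at
// different points, so the float result can differ.
float sum_backward(const std::vector<float>& v) {
    float s = 0.0f;
    for (auto it = v.rbegin(); it != v.rend(); ++it) s += *it;
    return s;
}
```

For `{1.0f, 1e8f, -1e8f}`, the forward sum loses the `1.0f` when it is added to `1e8f` (the spacing between floats near 1e8 is 8) and returns 0, while the backward sum cancels the large terms first and returns 1. If this is the cause, comparing the two paths with a small tolerance rather than exact equality would be the appropriate check.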

josephjaspers commented 4 years ago

I believe single_predict is working:

Added minor changes as of https://github.com/josephjaspers/blackcat_tensors/commit/b812d74ea4b7b68a21a277e8821c517d327f35d3

The batch_index determines which batch is copied from the training data to the prediction data.
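To illustrate what copying one batch entry into the single-prediction data could look like for a recurrent layer, here is a minimal sketch. It assumes, for illustration only, that the batched cell-state is stored column-major as a `hidden_size x batch_size` matrix and that `batch_index` picks one column (one sample). The struct and function names are hypothetical; the real `copy_training_data_to_single_predict` copies BlackCat Tensors' internal buffers and its layout may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical column-major batched state: element (row, col) lives at
// data[col * rows + row], so each column is one sample of the batch.
struct BatchedState {
    std::size_t rows;         // hidden size
    std::size_t cols;         // batch size
    std::vector<float> data;  // rows * cols values
};

// Copy column `batch_index` of the batched state into a single-sample state
// vector, so a single-sample predictor can continue from that sample's state.
std::vector<float> copy_column(const BatchedState& s, std::size_t batch_index) {
    if (batch_index >= s.cols) {
        throw std::out_of_range("batch_index out of range");
    }
    auto first = s.data.begin() + static_cast<std::ptrdiff_t>(batch_index * s.rows);
    return std::vector<float>(first, first + static_cast<std::ptrdiff_t>(s.rows));
}
```

Under this assumption, passing 0 selects the first sample of the batch, which is also the column that `predict(...).slice(0)` reports in the test below.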

Test: mnist_test_recurrent.

    network.copy_training_data_to_single_predict(0);
    {
        BC::size_t test_images = 10;
        cube img = cube(reshape(inputs[0], BC::shape(28,28, batch_size)));
        for (int i = 0; i < test_images; ++i) {
            auto batch = inputs[i];
            auto shape = BC::shape(784/4, batch_size);
            for (int p = 0; p < img_partitions-1; ++p) {
                auto index = BC::index(0,784 * (p/(float)img_partitions));
                network.predict(batch[{index, shape}]);
            }
            auto last_index = BC::index(0,784 * ((img_partitions-1)/(float)img_partitions));
            vec hyps = network.predict(batch[{last_index, shape}]).slice(0);
            hyps.print();
            std::cout << "------------------------------------" <<std::endl;
        }
    }
    BC::print("----");
    {

        BC::size_t test_images = 10;
        for (int i = 0; i < test_images; ++i) {

            auto batch = inputs[i];
            auto shape = BC::shape(784/4, batch_size);
            for (int p = 0; p < img_partitions-1; ++p) {
                auto index = BC::index(0,784 * (p/(float)img_partitions));
                network.single_predict(batch[{index, shape}][0]);
            }
            auto last_index = BC::index(0,784 * ((img_partitions-1)/(float)img_partitions));
            vec hyps = network.single_predict(batch[{last_index, shape}][0]);

            hyps.print();
            std::cout << "------------------------------------" <<std::endl;
        }
    }

Output:

Batch index: 20400 loss: [0.068659]
 training time: 221.102
 testing... 
[0.000143, 0.996308, 0.000511, 0.000001, 0.000005, 0.000033, 0.000007, 0.001364, 0.001567, 0.000061]
------------------------------------
[0.000037, 0.000025, 0.001304, 0.000030, 0.990513, 0.001262, 0.000524, 0.002080, 0.000039, 0.004186]
------------------------------------
[0.000005, 0.000000, 0.000026, 0.000001, 0.000582, 0.000009, 0.999119, 0.000006, 0.000007, 0.000244]
------------------------------------
[0.000000, 0.996579, 0.000211, 0.000008, 0.000169, 0.000000, 0.000005, 0.002914, 0.000028, 0.000086]
------------------------------------
[0.001025, 0.000002, 0.007655, 0.003944, 0.000641, 0.956958, 0.028563, 0.000000, 0.000848, 0.000362]
------------------------------------
[0.000000, 0.000002, 0.007094, 0.000015, 0.000296, 0.000012, 0.992436, 0.000000, 0.000140, 0.000004]
------------------------------------
[0.000000, 0.000019, 0.000470, 0.983806, 0.000002, 0.000458, 0.000000, 0.013112, 0.001382, 0.000750]
------------------------------------
[0.000001, 0.999644, 0.000006, 0.000000, 0.000059, 0.000005, 0.000024, 0.000026, 0.000223, 0.000012]
------------------------------------
[0.000000, 0.999957, 0.000003, 0.000015, 0.000011, 0.000000, 0.000009, 0.000005, 0.000000, 0.000001]
------------------------------------
[0.999835, 0.000000, 0.000000, 0.000001, 0.000044, 0.000005, 0.000001, 0.000000, 0.000019, 0.000096]
------------------------------------
----
[0.000143, 0.996308, 0.000511, 0.000001, 0.000005, 0.000033, 0.000007, 0.001364, 0.001567, 0.000061]
------------------------------------
[0.000037, 0.000025, 0.001304, 0.000030, 0.990513, 0.001262, 0.000524, 0.002080, 0.000039, 0.004186]
------------------------------------
[0.000005, 0.000000, 0.000026, 0.000001, 0.000582, 0.000009, 0.999119, 0.000006, 0.000007, 0.000244]
------------------------------------
[0.000000, 0.996579, 0.000211, 0.000008, 0.000169, 0.000000, 0.000005, 0.002914, 0.000028, 0.000086]
------------------------------------
[0.001025, 0.000002, 0.007655, 0.003944, 0.000641, 0.956958, 0.028563, 0.000000, 0.000848, 0.000362]
------------------------------------
[0.000000, 0.000002, 0.007094, 0.000015, 0.000296, 0.000012, 0.992436, 0.000000, 0.000140, 0.000004]
------------------------------------
[0.000000, 0.000019, 0.000470, 0.983806, 0.000002, 0.000458, 0.000000, 0.013112, 0.001382, 0.000750]
------------------------------------
[0.000001, 0.999644, 0.000006, 0.000000, 0.000059, 0.000005, 0.000024, 0.000026, 0.000223, 0.000012]
------------------------------------
[0.000000, 0.999957, 0.000003, 0.000015, 0.000011, 0.000000, 0.000009, 0.000005, 0.000000, 0.000001]
------------------------------------
[0.999835, 0.000000, 0.000000, 0.000001, 0.000044, 0.000005, 0.000001, 0.000000, 0.000019, 0.000096]
------------------------------------
 success 
joseph@joseph-F570UD:~/BlackCat_Tensors/BlackCat_Tensors/examples/mnist_test_recurrent$ 
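The test loop above feeds each 784-pixel MNIST image to the recurrent network as `img_partitions` consecutive slices. It computes each slice's start as `784 * (p / (float)img_partitions)` but hard-codes the slice length as `784/4`, so the two appear to agree only when `img_partitions == 4`. A minimal sketch of the same offset arithmetic in integers, with a check that the image divides evenly (the helper name `slice_offsets` is hypothetical):

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Start offsets of each slice when an image of `pixels` values is fed to a
// recurrent network as `partitions` equal, consecutive slices.
std::vector<std::size_t> slice_offsets(std::size_t pixels, std::size_t partitions) {
    if (partitions == 0 || pixels % partitions != 0) {
        throw std::invalid_argument("pixels must divide evenly into partitions");
    }
    const std::size_t slice = pixels / partitions;  // e.g. 784 / 4 == 196
    std::vector<std::size_t> offsets;
    for (std::size_t p = 0; p < partitions; ++p) {
        offsets.push_back(p * slice);
    }
    return offsets;
}
```

With four partitions this yields offsets 0, 196, 392 and 588, matching the float expression in the loop; deriving the slice length from `pixels / partitions` as well keeps offsets and lengths consistent for other partition counts.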
xinsuinizhuan commented 4 years ago

Yes, now predict and single_predict give the same output for the same input, but they differ from forward_propagation. Is that expected, or is something still wrong? Test: mnist_test_recurrent.

    BC::print("forward_propagation ----");
    {
        auto batch = inputs[0];
        auto shape = BC::shape(784 / 4, batch_size);
        for (int p = 0; p < img_partitions - 1; ++p) {
            auto index = BC::index(0, 784 * (p / (float)img_partitions));
            network.forward_propagation(batch[{index, shape}]);
        }
        auto last_index = BC::index(0, 784 * ((img_partitions - 1) / (float)img_partitions));
        mat hyps = network.forward_propagation(batch[{last_index, shape}]);

        BC::size_t test_images = 10;
        //cube img = cube(reshape(inputs[0], BC::shape(28, 28, batch_size)));
        for (int i = 0; i < test_images; ++i) {
            //img[i].t().print_sparse(3);
            hyps[i].print();
            std::cout << "------------------------------------" << std::endl;
        }
    }

    network.copy_training_data_to_single_predict(0);
    BC::print("predict ----");
    {
        BC::size_t test_images = 10;
        cube img = cube(reshape(inputs[0], BC::shape(28, 28, batch_size)));
        for (int i = 0; i < test_images; ++i) {
            auto batch = inputs[i];
            auto shape = BC::shape(784 / 4, batch_size);
            for (int p = 0; p < img_partitions - 1; ++p) {
                auto index = BC::index(0, 784 * (p / (float)img_partitions));
                network.predict(batch[{index, shape}]);
            }
            auto last_index = BC::index(0, 784 * ((img_partitions - 1) / (float)img_partitions));
            vec hyps = network.predict(batch[{last_index, shape}]).slice(0);
            hyps.print();
            std::cout << "------------------------------------" << std::endl;
        }
    }
    BC::print("single_predict ----");
    {
        BC::size_t test_images = 10;
        for (int i = 0; i < test_images; ++i) {
            auto batch = inputs[i];
            auto shape = BC::shape(784 / 4, batch_size);
            for (int p = 0; p < img_partitions - 1; ++p) {
                auto index = BC::index(0, 784 * (p / (float)img_partitions));
                network.single_predict(batch[{index, shape}][0]);
            }
            auto last_index = BC::index(0, 784 * ((img_partitions - 1) / (float)img_partitions));
            vec hyps = network.single_predict(batch[{last_index, shape}][0]);
            hyps.print();
            std::cout << "------------------------------------" << std::endl;
        }
    }

[image]

josephjaspers commented 4 years ago

On Windows, predict and forward_propagation use the same cell-state, so their outputs will differ if you run them one after another.

Do you need/want predict to have a separate cell-state from forward_propagation?
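To illustrate why sharing one cell-state makes back-to-back calls differ, here is a minimal scalar recurrent cell. It is a hypothetical toy, not the BlackCat Tensors LSTM: each call reads and overwrites the stored state, so the same input yields a different output depending on what ran before. Giving predict its own copy of the state, or resetting it, would make the result independent of earlier calls.

```cpp
#include <cassert>
#include <cmath>

// Toy recurrent cell: h_t = tanh(w_x * x + w_h * h_{t-1}).
// `h` is the carried cell-state that every call reads and overwrites.
struct ToyCell {
    double w_x = 0.5;
    double w_h = 0.8;
    double h = 0.0;

    double step(double x) {
        h = std::tanh(w_x * x + w_h * h);
        return h;
    }

    // Zero the carried state so the next call starts fresh.
    void reset() { h = 0.0; }
};
```

Calling `step(1.0)` twice on the same cell returns tanh(0.5) ≈ 0.462 and then a larger value (about 0.701), because the second call starts from the state the first one left behind; calling `reset()` in between makes both calls return the same value.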

xinsuinizhuan commented 4 years ago

Oh, so it is expected that predict and forward_propagation differ? I don't know whether predict should have a separate cell-state, but if I load a saved model, will it give the same results as the model before it was saved?

josephjaspers commented 4 years ago

Saving/loading currently stores only the weights, but it also needs to store the inputs/outputs of each layer.

(However, a loaded model is currently giving different results, so this is something I have to fix.) I added a ticket, so hopefully I will be able to fix this soon: https://github.com/josephjaspers/blackcat_tensors/issues/53
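The point above is that, for a recurrent network, restoring only the weights is not enough to reproduce the outputs: the carried per-layer state must be restored too. A minimal sketch with a hypothetical `LayerSnapshot` struct and plain-text serialization through standard C++ streams; the real BlackCat Tensors save format is not shown in this thread.

```cpp
#include <cassert>
#include <cstddef>
#include <iomanip>
#include <istream>
#include <limits>
#include <ostream>
#include <sstream>
#include <vector>

// Hypothetical recurrent layer: weights plus the carried cell-state.
struct LayerSnapshot {
    std::vector<double> weights;
    std::vector<double> cell_state;
};

// Write one vector as "<size> v0 v1 ...".
void write_vec(std::ostream& os, const std::vector<double>& v) {
    os << v.size();
    for (double x : v) os << ' ' << x;
    os << '\n';
}

std::vector<double> read_vec(std::istream& is) {
    std::size_t n = 0;
    is >> n;
    std::vector<double> v(n);
    for (double& x : v) is >> x;
    return v;
}

// Saving only `weights` would drop `cell_state`, so a reloaded network would
// start from a different state and produce different outputs.
void save(std::ostream& os, const LayerSnapshot& layer) {
    // Enough significant digits for an exact text round trip of each double
    // (this changes the stream's precision setting persistently).
    os << std::setprecision(std::numeric_limits<double>::max_digits10);
    write_vec(os, layer.weights);
    write_vec(os, layer.cell_state);
}

LayerSnapshot load(std::istream& is) {
    LayerSnapshot layer;
    layer.weights = read_vec(is);
    layer.cell_state = read_vec(is);
    return layer;
}
```

Writing doubles with `max_digits10` significant digits makes the text round trip exact; with the stream default of 6 digits the reloaded values would differ slightly, which is another way a saved and reloaded network can disagree with the original.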

xinsuinizhuan commented 4 years ago

OK. Thank you very much.

josephjaspers commented 4 years ago

The network now saves both variables and weights. Saving and loading do seem to succeed, but the results are not consistent; I still need to investigate why.

josephjaspers commented 4 years ago

This issue has been resolved. The save/load feature has a ticket here: https://github.com/josephjaspers/blackcat_tensors/issues/53