adatao / tensorspark

TensorFlow on Spark

Some modifications to accelerate the train function #13

Open younfor opened 7 years ago

younfor commented 7 years ago

In the file "paramservermodel.py", the function def train(self, labels, features) computes the gradients with this loop:

for i in range(len(self.compute_gradients)):
    self.gradients[i] += self.compute_gradients[i][0].eval(feed_dict=feed)

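The loop isn't necessary. Each .eval() call is a separate session run, so the graph is executed once per gradient. It can be replaced by a single session.run that fetches all the gradients (and the error rate) at once: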

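grads_and_vars, test_error_rate = self.session.run(
    [self.compute_gradients, self.error_rate], feed_dict=feed)
# accumulate (+=) like the original loop did, instead of overwriting
self.gradients[:] = [acc + g for acc, (g, _) in zip(self.gradients, grads_and_vars)]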

This makes GPU training several times faster: on MNIST, one iteration went from 7 ms to 1 ms on my computer.
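To illustrate why batching the fetches helps, here is a minimal TensorFlow-free sketch. CountingSession, make_fetches, train_per_tensor and train_batched are hypothetical names, and plain floats stand in for gradient arrays. The fake session counts run() calls. In real TensorFlow, each call would be a separate graph execution plus a device-to-host copy. The per-tensor loop issues one call per gradient, while the batched version issues one call in total. Because both accumulate with +=, they produce the same gradients.

```python
class CountingSession:
    """Hypothetical stand-in for a TensorFlow session.

    Every run() call is counted; in real TensorFlow each call is a separate
    graph execution (forward/backward pass plus a device-to-host copy).
    """

    def __init__(self):
        self.run_calls = 0

    def run(self, fetches, feed_dict):
        self.run_calls += 1
        # Evaluate every requested "tensor" inside this single call.
        return [fetch(feed_dict) for fetch in fetches]


def make_fetches(num_vars):
    # Hypothetical scalar "gradients": gradient k is (k + 1) * mean(x).
    def grad(k):
        return lambda feed: (k + 1) * sum(feed["x"]) / len(feed["x"])
    return [grad(k) for k in range(num_vars)]


def train_per_tensor(session, fetches, gradients, feed):
    # Original pattern: one run() per gradient, which is what .eval() does.
    for i, fetch in enumerate(fetches):
        gradients[i] += session.run([fetch], feed)[0]


def train_batched(session, fetches, gradients, feed):
    # Proposed pattern: one run() fetches every gradient, then accumulate.
    values = session.run(fetches, feed)
    gradients[:] = [acc + g for acc, g in zip(gradients, values)]


fetches = make_fetches(4)
feed = {"x": [1.0, 2.0, 3.0]}
slow, fast = CountingSession(), CountingSession()
grads_slow, grads_fast = [0.0] * 4, [0.0] * 4
train_per_tensor(slow, fetches, grads_slow, feed)
train_batched(fast, fetches, grads_fast, feed)
print(slow.run_calls, fast.run_calls)  # 4 1
print(grads_fast)  # [2.0, 4.0, 6.0, 8.0]
```

In the real model, self.compute_gradients is a list of (gradient, variable) pairs, so fetching it also copies each variable's current value back to the host. Fetching only the gradient tensors (for example [g for g, _ in self.compute_gradients]) may reduce that transfer further.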

illuzen commented 7 years ago

Excellent catch, @younfor! Makes sense, gotta think about that GPU bus...

Submit a PR?