younfor opened 7 years ago
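In `paramservermodel.py`, the function `def train(self, labels, features):` evaluates the gradients one at a time:

    for i in range(len(self.compute_gradients)):
        self.gradients[i] += self.compute_gradients[i][0].eval(feed_dict=feed)

The `for` loop is not necessary here. Each `.eval()` call is a separate `session.run`, so the forward pass is recomputed for every gradient. It can be replaced by a single call:

    grads, test_error_rate = self.session.run([self.compute_gradients, self.error_rate], feed_dict=feed)
    self.gradients[:] = [g[0] for g in grads]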
This makes GPU training several times faster: on MNIST, one training iteration drops from 7 ms to 1 ms on my machine.
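To see why batching the fetches helps, here is a minimal pure-Python sketch. `CountingSession`, `per_tensor` and `batched` are hypothetical names, not TensorFlow APIs: the fake session only counts how many times the graph would be executed, assuming (as in TF1) that every `run()` call recomputes the shared forward pass regardless of how many tensors are fetched. The sketch also keeps the `+=` accumulation of the original loop, because `self.gradients[:] = ...` overwrites the buffer and only matches the loop when the buffer starts at zero.

```python
class CountingSession:
    """Hypothetical stand-in for a TF1 session that counts graph executions."""

    def __init__(self):
        self.run_calls = 0

    def run(self, fetches):
        # Assumption: each run() executes the graph once (forward + backward),
        # however many tensors are requested. Fetches here are gradient ids.
        self.run_calls += 1
        if isinstance(fetches, list):
            return [self._grad(f) for f in fetches]
        return self._grad(fetches)

    @staticmethod
    def _grad(grad_id):
        # Fake gradient value so the two strategies can be compared.
        return float(grad_id) * 0.5


def per_tensor(session, grad_ids, accum):
    # Mirrors the original loop: one run() (one .eval()) per gradient.
    for i, g in enumerate(grad_ids):
        accum[i] += session.run(g)
    return accum


def batched(session, grad_ids, accum):
    # Mirrors the proposed fix: fetch every gradient in a single run().
    values = session.run(list(grad_ids))
    # Accumulate like the original "+=" instead of overwriting the buffer.
    accum[:] = [a + v for a, v in zip(accum, values)]
    return accum


grad_ids = range(8)
s1, s2 = CountingSession(), CountingSession()
a = per_tensor(s1, grad_ids, [1.0] * 8)
b = batched(s2, grad_ids, [1.0] * 8)
print(s1.run_calls, s2.run_calls)  # 8 1
assert a == b  # same gradients, one graph execution instead of eight
```

With eight gradient tensors the loop executes the graph eight times, while the batched call executes it once and returns the same values.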
Excellent catch, @younfor! Makes sense, gotta think about that GPU bus...
Submit a PR?