Closed: Jackaljkdan closed this issue 6 years ago.
I believe this is happening because you are using the default graph in each thread. Try creating a graph in the thread and using that instead of a default graph. On a side note - 64 threads sounds excessive. You are unlikely to get additional throughput at that sort of value.
If I create a graph for each thread, simply by using `with tf.Graph().as_default():` instead of the previous with statement, the following exception is raised for any number of threads, on the line `net.reset_states()` in `thread_fn`:

ValueError: Tensor("Placeholder:0", shape=(1, 1), dtype=float32) must be from the same graph as Tensor("lstm_2/Variable:0", shape=(1, 1), dtype=float32_ref).
I think this may be due to the fact that the networks are created outside of the threads, using a different graph. However, if I try creating them inside the threads, each in a different graph, i.e.
```python
def thread_fn(index, global_net):
    """
    :param index:
    :type global_net: Model
    """
    print("thread-%s" % index)
    local_graph = tf.Graph()
    # get global network weights
    with global_graph.as_default():
        gw = global_net.get_weights()
    with local_graph.as_default():
        net = make_rnet()
        net.reset_states()
        # sync weights with global network
        net.set_weights(gw)
        ...
```
the following exception is raised, again on the line `net.reset_states()` and for any number of threads:
TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("Placeholder:0", shape=(1, 1), dtype=float32) is not an element of this graph.
Also, I realize 64 is an excessive number; that value is only needed to reproduce the issue I experience in my real code, where I have 16 threads, each with 6 networks, one of which is selected based on some condition and used to make a prediction at every iteration.
The exact use case is confusing me a little. I'm not sure I understand the purpose behind setting weights on a global network from within the threads. Are you trying to train one global network and then have concurrent threads predicting with weights based off the global network? Or are you trying to train a global network and then modify its weights from within the threads? I think there may be an easier way to achieve what you're after, but it's hard to say without knowing this.
I modified your code. I may not have carried over the functionality you were after regarding the global network, but this trains one global model and then uses its weights to predict in threads.
```python
import threading

import numpy as np
from keras.layers import Input, LSTM
from keras.models import Model, model_from_json
import tensorflow as tf

n_threads = 64

def make_rnet():
    inp = Input(batch_shape=(1, 1, 5))
    out = LSTM(1)(inp)
    return Model(
        inputs=inp,
        outputs=out
    )

def thread_fn(index, architecture, weights):
    """
    :param index:
    :type architecture: str
    :type weights: list
    """
    print("thread-%s" % index)
    with tf.Session(graph=tf.Graph()) as sess:
        # Build model.
        net = model_from_json(architecture)
        net.set_weights(weights)
        net.compile(optimizer='rmsprop', loss='mse')
        in_shape = [int(d) for d in net.input.shape]
        out_shape = [int(d) for d in net.output.shape]
        # Test prediction.
        predictions = net.predict(np.ones(shape=in_shape))
        print(predictions)
        # Test fit on random data.
        x = np.random.random(size=in_shape)
        y = np.ones(shape=out_shape)
        net.fit(x, y, verbose=0, batch_size=1)

# Train a global network.
global_net = make_rnet()
global_net.compile(optimizer='rmsprop', loss='mse')
x = np.random.random(
    size=[int(d) for d in global_net.input.shape]
)
y = np.ones(
    shape=[int(d) for d in global_net.output.shape]
)
global_net.fit(x, y)

# Get the network in a portable format.
architecture = global_net.to_json()
weights = global_net.get_weights()

# Generate threads.
threads = [
    threading.Thread(target=thread_fn, args=(i, architecture, weights))
    for i in range(n_threads)
]
print("starting %s threads..." % n_threads)
for t in threads:
    t.start()
for t in threads:
    t.join()
print("threads terminated.")
```
Alternatively, if you do need to get data from the threads back out perhaps a queue implementation would work?
```python
import threading
import queue

import numpy as np
from keras.layers import Input, LSTM
from keras.models import Model, model_from_json
import tensorflow as tf

results_queue = queue.Queue()
n_threads = 32

def make_rnet():
    inp = Input(batch_shape=(1, 1, 5))
    out = LSTM(1)(inp)
    return Model(
        inputs=inp,
        outputs=out
    )

def thread_fn(index, results_queue, architecture, weights):
    """
    :param index:
    :type architecture: str
    :type weights: list
    """
    print("thread-%s" % index)
    with tf.Session(graph=tf.Graph()) as sess:
        # Build model.
        net = model_from_json(architecture)
        net.set_weights(weights)
        net.compile(optimizer='rmsprop', loss='mse')
        in_shape = [int(d) for d in net.input.shape]
        out_shape = [int(d) for d in net.output.shape]
        # Test prediction.
        predictions = net.predict(np.ones(shape=in_shape))
        print(predictions)
        # Test fit on random data.
        x = np.random.random(size=in_shape)
        y = np.ones(shape=out_shape)
        net.fit(x, y, verbose=0, batch_size=1)
        # Enqueue the new weights.
        results_queue.put({
            'weights': net.get_weights(),
            'index': index
        })

# Train a global network.
global_net = make_rnet()
global_net.compile(optimizer='rmsprop', loss='mse')
x = np.random.random(
    size=[int(d) for d in global_net.input.shape]
)
y = np.ones(
    shape=[int(d) for d in global_net.output.shape]
)
global_net.fit(x, y)

# Get the network in a portable format.
architecture = global_net.to_json()
weights = global_net.get_weights()

# Generate threads.
threads = [
    threading.Thread(target=thread_fn, args=(i, results_queue, architecture, weights))
    for i in range(n_threads)
]
print("starting %s threads..." % n_threads)
for t in threads:
    t.start()

while True:
    # Set weights on the global network from a thread's result.
    item = results_queue.get()
    print("Callback from thread-%s" % item['index'])
    global_net.set_weights(item['weights'])
    results_queue.task_done()
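One caveat with the collector loop above: `while True` never terminates. Since the number of worker threads is known up front, the main thread can instead collect exactly one result per worker and then stop. A minimal stdlib-only sketch of that pattern (the real training step is replaced by a stand-in):

```python
import queue
import threading

n_threads = 4
results = queue.Queue()

def thread_fn(index, results):
    # Stand-in for the per-thread training step: just enqueue a result.
    results.put({'index': index, 'weights': [index * 0.1]})

threads = [
    threading.Thread(target=thread_fn, args=(i, results))
    for i in range(n_threads)
]
for t in threads:
    t.start()

collected = []
# Collect exactly one result per thread instead of looping forever.
for _ in range(n_threads):
    item = results.get()
    collected.append(item['index'])
    results.task_done()

for t in threads:
    t.join()

print(sorted(collected))  # → [0, 1, 2, 3]
```

A sentinel value put on the queue by each finished worker would achieve the same thing if the number of results per worker were not fixed.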
Thank you for your help. I'm trying to implement A3C reinforcement learning (https://arxiv.org/abs/1602.01783) in Keras; the pseudocode in the paper is Algorithm S3. Each thread performs a loop in which it first synchronizes its weights with the global ones, uses them to build an episode (using .predict() in Keras), then trains on it and applies the gradients to the global network, and finally begins a new iteration. The algorithm is deliberately not thread-safe, in order to maximize throughput.
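The per-thread loop described above (sync, roll out an episode, train, push the update back) can be sketched as a plain-Python skeleton. Note this is purely structural: plain floats stand in for the real Keras weight arrays, and a constant "update" stands in for the episode rollout and training step.

```python
import threading

# A plain list stands in for the global network's weight arrays.
global_weights = [0.0]

def worker(index, n_iterations=3):
    for _ in range(n_iterations):
        # 1. Synchronize: copy the global weights into the local net.
        local_weights = list(global_weights)
        # 2. Build an episode with the local net (net.predict in Keras)
        #    and compute an update from it -- a constant stands in here.
        update = [0.5]
        # 3. Apply the update to the GLOBAL weights. A3C does this
        #    without a lock, so occasional lost updates are tolerated.
        for i, u in enumerate(update):
            global_weights[i] = local_weights[i] + u

threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(global_weights)  # some multiple of 0.5; exact value depends on interleaving
```

Because step 3 is deliberately lock-free, the final value is nondeterministic, which mirrors the "hogwild"-style updates the paper tolerates.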
I like your queue implementation. I was also thinking that I don't really need a global network, only global weights for the threads to synchronize with.
I'm not sure I understand why Keras is not throwing any exception in your thread implementation: I see you create a new session with a new graph in each thread, but the session doesn't appear to be used in the thread code. Still, if I remove that `with` statement, exceptions are raised about tensors not being elements of the correct graph. Is Keras using it automatically?
Hi, okay, thank you for the context.
Yes, agreed, I don't think you need a global network then; just keep the main weights in the main thread.
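If it helps, keeping just the global weights in a small lock-protected holder is enough when you do want safe concurrent access (A3C itself tolerates lock-free updates). A minimal sketch, with plain lists standing in for the real Keras weight arrays:

```python
import threading

class SharedWeights:
    """Thread-safe holder for the global weights."""

    def __init__(self, weights):
        self._weights = list(weights)
        self._lock = threading.Lock()

    def get(self):
        with self._lock:
            return list(self._weights)  # copy out under the lock

    def update(self, fn):
        # Read-modify-write as one atomic step, so no update is lost.
        with self._lock:
            self._weights = list(fn(self._weights))

shared = SharedWeights([0.0, 0.0])

def worker(index):
    # Stand-in for "train locally, then push the update to the global net".
    shared.update(lambda w: [v + 1.0 for v in w])

threads = [threading.Thread(target=worker, args=(i,)) for i in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(shared.get())  # → [8.0, 8.0]
```

In the real code, `get()` would feed `net.set_weights(...)` at the start of each episode and `update(...)` would apply the computed gradients.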
The reason it is not erroring within the thread is this line: https://github.com/keras-team/keras/blob/master/keras/backend/tensorflow_backend.py#L166

From TensorFlow's documentation:

> Returns the default session for the current thread. The returned Session will be the innermost session on which a Session or Session.as_default() context has been entered. NOTE: The default session is a property of the current thread. If you create a new thread, and wish to use the default session in that thread, you must explicitly add a with sess.as_default(): in that thread's function.

Without the `with` line, `tf.get_default_session()` returns None.
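The "default session is a property of the current thread" behaviour can be illustrated with plain Python's `threading.local` (this is an analogy built on stdlib stand-ins, not TensorFlow's actual internals):

```python
import threading

# threading.local() gives each thread its own attribute namespace;
# TensorFlow scopes the "default session" the same way, which is why
# a session entered in the main thread is invisible in a worker thread.
_state = threading.local()

def set_default_session(sess):
    _state.session = sess

def get_default_session():
    return getattr(_state, "session", None)

seen_in_child = {}

def worker():
    # Nothing was set in THIS thread, so the default is None --
    # just like tf.get_default_session() in a thread that never
    # entered `with sess.as_default():`.
    seen_in_child["session"] = get_default_session()

set_default_session("main-thread-session")  # set only in the main thread
t = threading.Thread(target=worker)
t.start()
t.join()

print(get_default_session())     # → main-thread-session
print(seen_in_child["session"])  # → None
```

This is why entering `with tf.Session(graph=tf.Graph())` inside each thread works: it installs a per-thread default session (and graph) that Keras then picks up automatically.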
Thank you!
Good! Thanks @DomHudson!
I have a problem with using Keras models in threads. I have been able to reproduce the issue in a MWE that I report below. Both here and in the real code, I create the networks outside the threads, then inside them I use a network to make predictions and to fit.
To make it work with as few as 2 threads, I had to use the workarounds described in https://github.com/keras-team/keras/issues/5896 and https://github.com/keras-team/keras/issues/6124 (i.e. they amount to using `with graph.as_default():` and `net._make_train_function()`). The program executes without any problem with 32 threads; however, when that number is increased to 64, exceptions are raised (see below).
Some of the exceptions raised with 64 threads are:
In the real code far fewer threads (8) are needed for this behaviour to occur, maybe because each thread uses many networks (6) and they have far more parameters (~3 million). Why is this happening?