tcfuji opened this issue 2 years ago
If you have plans to improve these functions, consider parallelizing them using `env_comm`, for example:
```python
def get_data_parallel(self, workflow):
    # Assumes: from mpi4py import MPI; import numpy as np; import tensorflow as tf;
    # and the framework's mpi_settings module.
    env_comm = mpi_settings.env_comm
    global_rank = mpi_settings.global_comm.rank

    # Defaults so every rank has something to pass into the broadcasts below
    batch_data = []
    first_offset = []
    chunk_size = []
    early_stop = False
    model_weights = []
    minibatch = []
    my_minibatch = []

    if env_comm.rank == 0:
        s_gendata_par = MPI.Wtime()
        self.gamma = workflow.agent.gamma
        batch_size = workflow.agent.batch_size
        self.device = workflow.agent.device
        self.target_model = workflow.agent.target_model
        memory_len = len(workflow.agent.memory)
        model_weights = self.target_model.get_weights()
        if memory_len < batch_size:
            # Not enough experiences yet: return zero-filled arrays and stop early
            early_stop = True
            batch_states = np.zeros(
                (batch_size, 1, workflow.env.observation_space.shape[0])
            ).astype("float64")
            batch_target = np.zeros(
                (batch_size, workflow.env.action_space.n)
            ).astype("float64")
            batch_data = batch_states, batch_target
            print("Time taken to generate data (in parallel) of batch size %s on %s ranks is %s"
                  % (str(batch_size), str(env_comm.size), str(MPI.Wtime() - s_gendata_par)))
        else:
            # Rank 0 keeps the remainder of the batch in addition to its own chunk
            first_offset = int(batch_size / env_comm.size) + (batch_size % env_comm.size)
            chunk_size = int(batch_size / env_comm.size)
            minibatch = workflow.agent.get_minibatch()

    # All ranks participate in the broadcasts
    early_stop = env_comm.bcast(early_stop, root=0)
    model_weights = env_comm.bcast(model_weights, root=0)
    if early_stop:
        batch_data = env_comm.bcast(batch_data, root=0)
        return batch_data

    self.gamma = env_comm.bcast(self.gamma, root=0)
    self.device = env_comm.bcast(self.device, root=0)
    first_offset = env_comm.bcast(first_offset, root=0)
    chunk_size = env_comm.bcast(chunk_size, root=0)
    minibatch = env_comm.bcast(minibatch, root=0)
    if env_comm.rank != 0:
        self.target_model.set_weights(model_weights)

    # Each rank takes its own slice of the minibatch
    if env_comm.rank == 0:
        my_minibatch = minibatch[:first_offset]
    else:
        my_minibatch = minibatch[
            int(first_offset + (env_comm.rank - 1) * chunk_size):
            int(first_offset + env_comm.rank * chunk_size)
        ]

    # Generate the training data on each process
    batch_target = list(map(self.calc_target_f_parallel, my_minibatch))
    try:
        batch_states = [np.array(exp[0]).reshape(1, 1, len(exp[0]))[0] for exp in my_minibatch]
        batch_states = np.reshape(
            batch_states, [len(my_minibatch), 1, len(my_minibatch[0][0])]
        ).astype("float64")
        batch_target = np.reshape(
            batch_target, [len(my_minibatch), workflow.env.action_space.n]
        ).astype("float64")
    except Exception:
        print("Global Rank: %s Environment Local Rank = %s Length of my mini-batch: %s"
              % (str(global_rank), str(env_comm.rank), str(len(my_minibatch))))
        print("My mini-batch: %s" % (str(my_minibatch)))
        print("Size of the original minibatch: %s" % (str(len(minibatch))))
        print("Original minibatch: %s" % (str(minibatch)))

    batch_states = env_comm.gather(batch_states, root=0)
    batch_target = env_comm.gather(batch_target, root=0)
    if env_comm.rank == 0:
        print("Time taken to generate data (in parallel) of batch size %s on %s ranks is %s"
              % (str(batch_size), str(env_comm.size), str(MPI.Wtime() - s_gendata_par)))
    return batch_states, batch_target

def calc_target_f_parallel(self, exp):
    state, action, reward, next_state, done = exp
    np_state = np.array(state).reshape(1, 1, len(state))
    np_next_state = np.array(next_state).reshape(1, 1, len(next_state))
    expectedQ = 0
    if not done:
        with tf.device(self.device):
            expectedQ = self.gamma * np.amax(
                self.target_model.predict(np_next_state)[0]
            )
    target = reward + expectedQ
    with tf.device(self.device):
        target_f = self.target_model.predict(np_state)
    target_f[0][action] = target
    return target_f[0]
```
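For context on the split above: rank 0 takes `chunk_size` plus the remainder of `batch_size % env_comm.size`, and every other rank takes a plain `chunk_size` slice. A quick standalone check of that arithmetic (illustrative values only, not taken from the code above):

```python
# Illustrative only: how the minibatch is split across env_comm ranks above.
batch_size, n_ranks = 32, 5
chunk_size = batch_size // n_ranks                 # 6
first_offset = chunk_size + batch_size % n_ranks   # 6 + 2 = 8

slices = [(0, first_offset)]                       # rank 0 keeps the remainder
for rank in range(1, n_ranks):
    start = first_offset + (rank - 1) * chunk_size
    slices.append((start, start + chunk_size))

print(slices)  # [(0, 8), (8, 14), (14, 20), (20, 26), (26, 32)]
```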
Currently, these methods make some strange implementation decisions that might negatively impact performance.
In `generate_data`, the data is output as numpy arrays and not pytorch tensors. `generate_data` calls `calc_target_f` ... For some reason, `calc_target_f` is converting the outputs to numpy arrays (look at line 307). This forces `generate_data` to set the output types to be numpy arrays. Unless there is a clear reason for this, we should fix this so that all the data are pytorch tensors.
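A minimal sketch of what the tensor-only path could look like, assuming a PyTorch `target_model` that maps a `(1, 1, state_dim)` float tensor to per-action Q-values; the method bodies below are illustrative, not the repo's actual implementation:

```python
# Sketch only: agent methods assuming a PyTorch target_model returning (1, n_actions) Q-values.
import torch

def calc_target_f(self, exp):
    state, action, reward, next_state, done = exp
    state_t = torch.as_tensor(state, dtype=torch.float32, device=self.device).view(1, 1, -1)
    next_state_t = torch.as_tensor(next_state, dtype=torch.float32, device=self.device).view(1, 1, -1)

    with torch.no_grad():
        expected_q = 0.0
        if not done:
            # Bootstrapped return from the target network
            expected_q = self.gamma * self.target_model(next_state_t)[0].max().item()
        target = reward + expected_q

        # Keep the per-action targets as a tensor instead of converting to numpy
        target_f = self.target_model(state_t)[0].clone()
        target_f[action] = target
    return target_f  # torch.Tensor of shape (n_actions,)

def generate_data(self, minibatch):
    # Stack states and targets directly into tensors, avoiding numpy round-trips
    batch_states = torch.stack(
        [torch.as_tensor(exp[0], dtype=torch.float32) for exp in minibatch]
    ).unsqueeze(1)                                   # (batch, 1, state_dim)
    batch_target = torch.stack([self.calc_target_f(exp) for exp in minibatch])
    return batch_states, batch_target
```

Keeping everything as torch tensors avoids the numpy round-trip described above and lets the training step consume the batch directly (or move it to the GPU with `.to(device)`).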