exalearn / EXARL

Scalable Framework for Reinforcement Learning

Improve calc_target_f and generate_data methods in dqn.py #219

Open tcfuji opened 2 years ago

tcfuji commented 2 years ago

Currently, these methods make some strange implementation decisions that might negatively impact performance.

In generate_data, the data is output as numpy arrays and not pytorch tensors. generate_data calls calc_target_f...

For some reason, calc_target_f converts its outputs to numpy arrays (see line 307). This forces generate_data to return numpy arrays as well.

Unless there is a clear reason for this, we should change these methods so that all the data stay as pytorch tensors.
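
For illustration, here is a minimal sketch of what a tensor-native calc_target_f could look like, assuming a PyTorch target_model that is callable on tensors and the same self.gamma / self.device attributes the current code uses; the exact shapes and attribute names in dqn.py may differ:

    # Sketch only (assumes "import torch" at the top of dqn.py and a PyTorch target_model).
    def calc_target_f(self, exp):
        state, action, reward, next_state, done = exp
        state_t = torch.as_tensor(state, dtype=torch.float32, device=self.device).reshape(1, 1, -1)
        next_state_t = torch.as_tensor(next_state, dtype=torch.float32, device=self.device).reshape(1, 1, -1)

        expected_q = 0.0
        if not done:
            with torch.no_grad():
                # Max Q-value of the next state from the target network
                expected_q = self.gamma * self.target_model(next_state_t).max().item()

        with torch.no_grad():
            target_f = self.target_model(state_t).clone()
        target_f.view(-1)[action] = reward + expected_q
        return target_f  # stays a torch tensor; no .numpy() conversion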

rvinaybharadwaj commented 2 years ago

If you have plans to improve these functions, consider parallelizing them using env_comm, for example:

    # Assumes the surrounding module provides numpy as np, tensorflow as tf,
    # mpi4py's MPI, and EXARL's mpi_settings.
    def get_data_parallel(self, workflow):
        env_comm = mpi_settings.env_comm
        global_rank = mpi_settings.global_comm.rank
        #print("Rank %s " % (str(env_comm.rank)))
        #batch_size = workflow.agent.batch_size
        #print("Rank %s has batch_size = %s" % (str(env_comm.rank),str(batch_size)))
        batch_data = []
        #gamma = []
        #device = []
        #target_model = []
        first_offset = []
        chunk_size = []
        batch_data_part = []
        early_stop = False
        #model_type = []
        model_weights = []
        #env_comm = mpi_settings.env_comm
        #chunk_size = int(batch_size / env_comm.size)
        minibatch = []
        my_minibatch = []
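        # Rank 0 reads the hyperparameters, samples the replay minibatch, and computes the per-rank split sizes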
        if (env_comm.rank == 0):
            s_gendata_par = MPI.Wtime()
            self.gamma = workflow.agent.gamma
            batch_size = workflow.agent.batch_size
            self.device = workflow.agent.device
            self.target_model = workflow.agent.target_model
            memory_len = len(workflow.agent.memory)
            #model_type = workflow.agent.model_type
            model_weights = self.target_model.get_weights()
            if (memory_len < batch_size):
                early_stop = True
                batch_states = np.zeros(
                    (batch_size, 1, workflow.env.observation_space.shape[0])
                ).astype("float64")
                batch_target = np.zeros((batch_size, workflow.env.action_space.n)).astype(
                    "float64"
                )
                batch_data = batch_states,batch_target
                print("Time taken to generate data(parallely) of batch size %s on %s ranks is %s)" % (str(batch_size),str(env_comm.size),str(MPI.Wtime()-s_gendata_par)))
            else:
                first_offset = int(batch_size/env_comm.size) + (batch_size%env_comm.size)
                chunk_size =  int(batch_size/env_comm.size)
                minibatch = workflow.agent.get_minibatch()

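        # Share the early-stop flag, model weights, hyperparameters, and minibatch with every rank in env_comm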
        early_stop = env_comm.bcast(early_stop,root=0)
        #model_type = env_comm.bcast(model_type,root=0)
        model_weights = env_comm.bcast(model_weights,root=0)
        if early_stop:
            #print("Global rank %s Env local Rank %s reached early stop!" % (str(global_rank),str(env_comm.rank)))
            batch_data = env_comm.bcast(batch_data,root=0)
            #print("Global Rank %s Env local Rank %s batch data after early stop is %s " % (str(global_rank),str(env_comm.rank),str(batch_data)))
            return batch_data

        self.gamma = env_comm.bcast(self.gamma,root=0)
        self.device = env_comm.bcast(self.device,root=0)
        #target_model = env_comm.bcast(target_model,root=0)
        first_offset = env_comm.bcast(first_offset,root=0)
        chunk_size = env_comm.bcast(chunk_size,root=0)
        minibatch = env_comm.bcast(minibatch,root=0)
        if (env_comm.rank != 0):
            self.target_model.set_weights(model_weights)

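        # Split the minibatch: rank 0 takes the first first_offset samples, every other rank a chunk_size slice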
        if(env_comm.rank == 0):
            my_minibatch = minibatch[:first_offset]
        else:
            my_minibatch = minibatch[int(first_offset+((env_comm.rank-1)*chunk_size)):int(first_offset+((env_comm.rank)*chunk_size))]

        # Generate the training data on each process
        batch_target = list(map(self.calc_target_f_parallel, my_minibatch))
        try:
            batch_states = [np.array(exp[0]).reshape(1, 1, len(exp[0]))[0] for exp in my_minibatch]
            batch_states = np.reshape(batch_states, [len(my_minibatch), 1, len(my_minibatch[0][0])]).astype("float64")
            batch_target = np.reshape(batch_target, [len(my_minibatch), workflow.env.action_space.n]).astype("float64")
        except Exception:
            print("Global Rank: %s Environment Local Rank = %s Length of my mini-batch: %s " % (str(global_rank),str(env_comm.rank),str(len(my_minibatch))))
            print("My mini-batch: %s" % (str(my_minibatch)))
            print("Size of the original minibatch: %s" % (str(len(minibatch))))
            print("Original minibatch: %s" % (str(minibatch)))

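        # Collect each rank's partial states and targets back on rank 0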
        batch_states = env_comm.gather(batch_states,root=0)
        batch_target = env_comm.gather(batch_target,root=0)

        if (env_comm.rank == 0):
            print("Time taken to generate data(parallely) of batch size %s on %s ranks is %s)" % (str(batch_size),str(env_comm.size),str(MPI.Wtime()-s_gendata_par)))

        return batch_states, batch_target

    # Per-sample Q-target computation, run independently on each rank
    def calc_target_f_parallel(self, exp):
        state, action, reward, next_state, done = exp
        np_state = np.array(state).reshape(1, 1, len(state))
        np_next_state = np.array(next_state).reshape(1, 1, len(next_state))
        expectedQ = 0
        if not done:
            with tf.device(self.device):
                expectedQ = self.gamma * np.amax(
                    self.target_model.predict(np_next_state)[0]
                )
        target = reward + expectedQ
        with tf.device(self.device):
            target_f = self.target_model.predict(np_state)
        target_f[0][action] = target
        return target_f[0]
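
For reference, the split above gives rank 0 the integer-division remainder on top of its chunk, so the whole minibatch is covered even when batch_size is not divisible by env_comm.size. A small standalone illustration of the indexing (the numbers are only an example, not taken from EXARL):

# Illustration of the partition used in get_data_parallel
batch_size, n_ranks = 32, 5
chunk_size = batch_size // n_ranks                    # 6
first_offset = chunk_size + batch_size % n_ranks      # 8: rank 0 absorbs the remainder
slices = [(0, first_offset)] + [
    (first_offset + (r - 1) * chunk_size, first_offset + r * chunk_size)
    for r in range(1, n_ranks)
]
print(slices)  # [(0, 8), (8, 14), (14, 20), (20, 26), (26, 32)]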