RchalYang / Soft-Module

Code for "Multi-task Reinforcement Learning with Soft Modularization"
https://rchalyang.github.io/SoftModule
109 stars · 19 forks

Synchronization problem of policy between collector and MTSAC trainer #4

Closed gaojl19 closed 1 year ago

gaojl19 commented 2 years ago

@RchalYang Thank you so much for sharing your code! I was running some experiments based on Soft-Module and found that the pf used for evaluation in the collector is sometimes not the current pf in MTSAC (the one just updated during the epoch). This causes an inconsistency between the performance of saved models and their evaluation results: the code uses the networks in the collector to generate evaluation results and decide when to update "model_pf_best.pth", but saves the state_dict from the networks in MTSAC. So if the pf model used in evaluation (in the collector) is not the same as the model in MTSAC, the code records the current pf network from MTSAC rather than the one actually used in evaluation.
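One quick way to surface such a mismatch is to diff the two state dicts key by key. Below is a minimal sketch of the idea; the helper name `diff_state_dicts` is hypothetical, and state dicts are represented here as plain key → list-of-floats mappings for illustration (the real torch state_dicts hold tensors, which would be compared with `torch.equal` or `torch.allclose` instead):

``` python
def diff_state_dicts(sd_a, sd_b, atol=1e-8):
    """Return the parameter names whose values differ between two state dicts."""
    mismatched = []
    for key in sd_a:
        a, b = sd_a[key], sd_b[key]
        if len(a) != len(b) or any(abs(x - y) > atol for x, y in zip(a, b)):
            mismatched.append(key)
    return mismatched

# Trainer weights after the update vs. the (stale) weights a collector worker used:
trainer_sd   = {'base.fc0.weight': [-0.0264, 0.0463], 'base.fc0.bias': [0.01]}
collector_sd = {'base.fc0.weight': [0.0050, 0.0437],  'base.fc0.bias': [0.01]}
print(diff_state_dicts(trainer_sd, collector_sd))  # -> ['base.fc0.weight']
```

A non-empty result for the pf returned by a collector worker would confirm that the evaluation ran on weights different from the ones being saved.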

To reproduce

Using Python 3.7.10 and PyTorch 1.7.0.

1. In torchrl/collector/para/async_mt.py, class AsyncMultiTaskParallelCollectorUniform(AsyncSingleTaskParallelCollector):

``` python
@staticmethod
def eval_worker_process(shared_pf, ...):
    ...
    shared_que.put({
        'eval_rewards': eval_rews,
        'success_rate': success / env_info.eval_episodes,
        'task_name': task_name,
        'pf_state_dict': pf.state_dict()   # add this line
    })
```

``` python
def eval_one_epoch(self):
    ...
    state_dict = []  # add this line
    for _ in range(self.eval_worker_nums):
        worker_rst = self.eval_shared_que.get()
        if worker_rst["eval_rewards"] is not None:
            active_task_counts += 1
            eval_rews += worker_rst["eval_rewards"]
            mean_success_rate += worker_rst["success_rate"]
            tasks_result.append((worker_rst["task_name"],
                                 worker_rst["success_rate"],
                                 np.mean(worker_rst["eval_rewards"])))
            state_dict.append(worker_rst['pf_state_dict'])  # add this line
    ...
    dic['mean_success_rate'] = mean_success_rate / active_task_counts
    dic['state_dict'] = state_dict  # add this line
    return dic
```

(This passes the pf state dict actually used during evaluation in each worker process back to the trainer.)

2. In torchrl/algo/rl_algo.py, function train():
``` python
def train(self):
    ...
    for reward in eval_infos["eval_rewards"]:
        self.episode_rewards.append(reward)
    # del eval_infos["eval_rewards"]

    # add the following lines
    print("current pf_dict")
    for name, network in self.snapshot_networks:
        if name == 'pf':
            print(network.state_dict()['base.fc0.weight'])

    print("current collector pf_dict")
    for state_dict in eval_infos['state_dict']:
        print(state_dict['base.fc0.weight'])
```

Then train mt10_fixed_shallow; the output:

```
2022-01-16 06:35:23,897 MainThread INFO: Finished Pretrain
current pf_dict
tensor([[ 0.0050,  0.0437, -0.0044,  ...,  0.0339, -0.0154, -0.0314],
        [ 0.0525, -0.0376,  0.0404,  ...,  0.0587, -0.0008,  0.0007],
        [-0.0723, -0.0500,  0.0435,  ..., -0.0295,  0.0371,  0.0737],
        ...,
        [ 0.0258,  0.0019,  0.0116,  ...,  0.1017,  0.0180, -0.0306],
        [-0.0036, -0.0427,  0.0276,  ...,  0.0337,  0.0001,  0.0388],
        [ 0.0093, -0.0112,  0.0333,  ..., -0.0134,  0.0433,  0.0355]])
current collector pf_dict
tensor([[ 0.0050,  0.0437, -0.0044,  ...,  0.0339, -0.0154, -0.0314],
        [ 0.0525, -0.0376,  0.0404,  ...,  0.0587, -0.0008,  0.0007],
        [-0.0723, -0.0500,  0.0435,  ..., -0.0295,  0.0371,  0.0737],
        ...,
        [ 0.0258,  0.0019,  0.0116,  ...,  0.1017,  0.0180, -0.0306],
        [-0.0036, -0.0427,  0.0276,  ...,  0.0337,  0.0001,  0.0388],
        [ 0.0093, -0.0112,  0.0333,  ..., -0.0134,  0.0433,  0.0355]])
(every one of the remaining eval workers prints this same tensor)
2022-01-16 06:39:08,773 MainThread INFO: EPOCH:0
2022-01-16 06:39:08,773 MainThread INFO: Time Consumed:224.8757426738739s
2022-01-16 06:39:08,773 MainThread INFO: Total Frames:42000s
```

```
...
current pf_dict
tensor([[-2.6430e-02,  4.6329e-02,  8.7816e-05,  ...,  9.8133e-03,
         -2.0852e-02, -3.7802e-02],
        [ 2.1514e-02, -3.1466e-02,  4.7950e-02,  ...,  3.4531e-02,
         -3.7469e-03, -4.1485e-03],
        [-4.2676e-02, -6.0403e-02,  3.2930e-02,  ..., -1.7281e-02,
          3.8448e-02,  8.5184e-02],
        ...,
        [-4.0366e-03,  9.0814e-03,  1.9489e-02,  ...,  8.5376e-02,
          1.4941e-02, -3.9423e-02],
        [ 2.8164e-02, -5.6839e-02,  1.3960e-02,  ...,  3.6394e-02,
         -6.9350e-04,  5.4603e-02],
        [ 3.9544e-02, -2.2280e-02,  2.2379e-02,  ..., -4.6273e-03,
          4.4896e-02,  4.9126e-02]])
current collector pf_dict
tensor([[ 0.0050,  0.0437, -0.0044,  ...,  0.0339, -0.0154, -0.0314],
        [ 0.0525, -0.0376,  0.0404,  ...,  0.0587, -0.0008,  0.0007],
        [-0.0723, -0.0500,  0.0435,  ..., -0.0295,  0.0371,  0.0737],
        ...,
        [ 0.0258,  0.0019,  0.0116,  ...,  0.1017,  0.0180, -0.0306],
        [-0.0036, -0.0427,  0.0276,  ...,  0.0337,  0.0001,  0.0388],
        [ 0.0093, -0.0112,  0.0333,  ..., -0.0134,  0.0433,  0.0355]])
(of the 10 eval workers, workers 1-4, 6-8 and 10 print the stale epoch-0
tensor above, while workers 5 and 9 print the updated pf_dict tensor)
2022-01-16 06:43:27,157 MainThread INFO: EPOCH:1
2022-01-16 06:43:27,157 MainThread INFO: Time Consumed:257.96906661987305s
2022-01-16 06:43:27,157 MainThread INFO: Total Frames:44000s
```

(On epoch 0 everything looks fine, but on epoch 1 some of the processes use the newly updated pf model for evaluation while others still use the older version of pf.) Could you please comment on this? I am also a bit confused about how you keep the pf model in the collector synchronized with the model in MTSAC after calling start_worker(): the pf model is updated every epoch, but you can only pass the current pf model once, when the processes are created. It would be very helpful if you could provide some hints on this.

Thanks in advance!

RchalYang commented 2 years ago

Hi, sorry for the late response. Regarding policy synchronization: when we create the workers, we pass a shared version of the policy to each worker, and that shared copy is what keeps the worker's policy version in sync. Due to the async-style sampling, the collector sometimes does not use the most up-to-date policy for sampling, but this should be fine: the worker's policy still gets updated through the shared copy, and SAC is an off-policy method that can cope with a small policy mismatch.
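The shared-policy mechanism described above can be illustrated with a minimal standard-library sketch, a stand-in for the `torch.multiprocessing` / `share_memory()` pattern; the names `eval_worker`, `demo` and the three-element weight buffer are hypothetical. The point is that a worker created once still sees in-place updates the trainer later makes to the shared buffer, with no explicit re-send:

``` python
import multiprocessing as mp

def eval_worker(weights, updated, out_q):
    # Block until the "trainer" signals that it has written new weights,
    # then read them straight out of the shared buffer.
    updated.wait()
    out_q.put(list(weights))

def demo():
    weights = mp.Array('d', [0.0, 0.0, 0.0])  # shared parameter buffer
    updated = mp.Event()
    out_q = mp.Queue()
    p = mp.Process(target=eval_worker, args=(weights, updated, out_q))
    p.start()
    # "Trainer" updates the parameters in place; the worker sees the new
    # values because both processes reference the same shared memory.
    for i in range(3):
        weights[i] = float(i + 1)
    updated.set()
    seen = out_q.get()
    p.join()
    return seen

if __name__ == '__main__':
    print(demo())  # the worker observes the post-update weights
```

In the actual code the shared object is the policy network itself rather than a raw array, but the mechanism is the same: synchronization happens through shared parameter memory, not through message passing, which is why a worker can occasionally run with weights from just before an update.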