Closed · gaojl19 closed this issue 1 year ago
Hi, sorry for the late response. Regarding policy synchronization: when we create the workers, we pass a shared version of the policy to each worker, which is then used to synchronize the policy version. Due to the async-style sampling, the collector sometimes samples with a policy that is not the most up-to-date one, but this should be fine: the worker will update its policy, and SAC is an off-policy method that can cope with a small policy mismatch.
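To make the mechanism described above concrete, here is a minimal single-process sketch of shared-memory policy synchronization in PyTorch. All names (`pf`, `shared_pf`, `local_pf`) are illustrative, not the repo's actual identifiers; the real code runs the worker side in separate processes.

```python
import copy
import torch
import torch.nn as nn

# Trainer side: keep a private policy `pf` and a shared-memory copy
# `shared_pf` that is handed to each worker when it is created. Because the
# shared copy's tensors live in shared memory, in-place writes by the trainer
# are visible to every worker process.
pf = nn.Linear(4, 2)
shared_pf = copy.deepcopy(pf)
shared_pf.share_memory()  # move parameters/buffers into shared memory

# After an update step, the trainer pushes its new weights into the shared copy.
with torch.no_grad():
    pf.weight.add_(1.0)  # stand-in for a real gradient step
    for p_shared, p in zip(shared_pf.parameters(), pf.parameters()):
        p_shared.data.copy_(p.data)  # in-place copy: workers see this

# Worker side: before sampling/evaluation, reload from the shared copy.
local_pf = nn.Linear(4, 2)  # the worker's private, possibly stale policy
local_pf.load_state_dict(shared_pf.state_dict())
```

If a worker skips the reload step for an epoch, it samples with a slightly stale policy, which is the benign mismatch the comment above refers to.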
@RchalYang Thank you so much for sharing your code! I was running some experiments based on Soft-Module and found that the pf used for evaluation in the collector is sometimes not the current pf in MTSAC (the one that has just been updated during the epoch). This causes an inconsistency between the performance of saved models and their evaluation results: the code uses the networks in the collector to generate evaluation results and decide when to update "model_pf_best.pth", but it saves the state_dict of the networks in MTSAC. So if the pf used in evaluation (the collector's) is not the same as the pf in MTSAC, the code records the current pf in MTSAC rather than the one actually used in evaluation.
To reproduce
Using Python 3.7.10, PyTorch 1.7.0:
```python
class AsyncMultiTaskParallelCollectorUniform(AsyncSingleTaskParallelCollector):
    def eval_one_epoch(self):
        ...
        state_dict = []  # add this line
        for _ in range(self.eval_worker_nums):
            worker_rst = self.eval_shared_que.get()
            if worker_rst["eval_rewards"] is not None:
                active_task_counts += 1
                eval_rews += worker_rst["eval_rewards"]
                mean_success_rate += worker_rst["success_rate"]
                tasks_result.append((worker_rst["task_name"],
                                     worker_rst["success_rate"],
                                     np.mean(worker_rst["eval_rewards"])))
                state_dict.append(worker_rst['pf_state_dict'])  # add this line
        ...
        dic['mean_success_rate'] = mean_success_rate / active_task_counts
        dic['state_dict'] = state_dict  # add this line
        return dic
```
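With the worker state dicts collected as above, a small helper (hypothetical, not part of the repo) can make the mismatch explicit by comparing each returned `pf_state_dict` against the trainer's current pf. The networks below are stand-ins for the real pf:

```python
import copy
import torch
import torch.nn as nn

# Hypothetical helper: True iff two state dicts have the same keys and
# element-wise identical tensors.
def state_dicts_match(sd_a, sd_b):
    if sd_a.keys() != sd_b.keys():
        return False
    return all(torch.equal(sd_a[k], sd_b[k]) for k in sd_a)

# Illustration: a trainer policy and a worker copy that missed one update.
trainer_pf = nn.Linear(4, 2)
stale_pf = copy.deepcopy(trainer_pf)
with torch.no_grad():
    stale_pf.weight.add_(1.0)  # simulate the stale evaluation policy

same = state_dicts_match(trainer_pf.state_dict(), trainer_pf.state_dict())
stale = state_dicts_match(trainer_pf.state_dict(), stale_pf.state_dict())
```

Running this check on each `worker_rst['pf_state_dict']` is how the stale workers in the experiment below were identified.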
Then train mt10_fixed_shallow. Result:
(On epoch 0 everything looks great, but on epoch 1 some of the processes use the newly updated pf for evaluation while others still use an older version of pf.) Could you please comment on this? I am also a bit confused about how you keep the pf in the collector synchronized with the pf in MTSAC after calling `start_worker()`, because the pf should update every epoch but you can only pass the current pf when creating the processes. It would be very helpful if you could provide some hints on this. Thanks in advance!
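For reference, one way the post-`start_worker()` synchronization can work, per the version-based scheme mentioned in the reply above, is sketched here with illustrative names (this is not the repo's actual code): the trainer bumps a shared version counter after each epoch, and a worker reloads from the shared policy only when its cached version is stale.

```python
import copy
import torch
import torch.nn as nn

# Shared version counter, bumped by the trainer after publishing new weights.
shared_version = torch.zeros(1, dtype=torch.long)
shared_version.share_memory_()

class WorkerPolicySync:
    """Worker-side helper: pull new weights only when the version changes."""
    def __init__(self, local_pf, shared_pf):
        self.local_pf = local_pf
        self.shared_pf = shared_pf
        self.local_version = -1

    def maybe_sync(self):
        v = int(shared_version.item())
        if v != self.local_version:  # trainer has published newer weights
            self.local_pf.load_state_dict(self.shared_pf.state_dict())
            self.local_version = v
            return True
        return False

# Trainer publishes an update; the worker syncs exactly once per version bump.
shared_pf = nn.Linear(4, 2)
shared_pf.share_memory()
sync = WorkerPolicySync(copy.deepcopy(shared_pf), shared_pf)
shared_version += 1
synced_now = sync.maybe_sync()    # stale -> reloads
synced_again = sync.maybe_sync()  # same version -> no reload
```

If a worker only calls `maybe_sync()` between episodes, it can briefly evaluate with the previous version, which would produce exactly the epoch-1 mismatch observed above.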