Eclectic-Sheep / sheeprl

Distributed Reinforcement Learning accelerated by Lightning Fabric
https://eclecticsheep.ai
Apache License 2.0

weight synchronization between player and main models in DreamerV1 #323

Open: Andy-Mielk opened this issue 1 week ago

Andy-Mielk commented 1 week ago

Hello, I'm studying the DreamerV1 implementation in your repository, and I have a question about how the weights are synchronized between the player and the main models (world_model, actor). In the build_agent function, I see that the player is initialized with deep copies of the main model components:

    # Each player component starts as an independent deep copy of the main component
    player = PlayerDV1(
        copy.deepcopy(world_model.encoder),
        copy.deepcopy(world_model.rssm.recurrent_model),
        copy.deepcopy(world_model.rssm.representation_model),
        copy.deepcopy(actor),
        actions_dim,
        cfg.env.num_envs,
        cfg.algo.world_model.stochastic_size,
        cfg.algo.world_model.recurrent_model.recurrent_state_size,
        fabric_player.device,
    )
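A minimal sketch of what copy.deepcopy does on its own, assuming PyTorch and using a hypothetical nn.Linear as a stand-in for the real sheeprl encoder: the copy starts with identical values but gets its own memory, so an in-place update to the original is not visible in the copy.

```python
import copy

import torch
import torch.nn as nn

# Hypothetical stand-in for a world-model component (not the actual sheeprl encoder)
encoder = nn.Linear(4, 2)
player_encoder = copy.deepcopy(encoder)

# Right after deepcopy, the values match but the memory is separate
values_match = torch.equal(encoder.weight, player_encoder.weight)
shares_storage = encoder.weight.data_ptr() == player_encoder.weight.data_ptr()

# An in-place update to the original (as an optimizer step would do) is not seen by the copy
with torch.no_grad():
    encoder.weight.add_(1.0)
still_match = torch.equal(encoder.weight, player_encoder.weight)

print(values_match, shares_storage, still_match)  # True False False
```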

Then, the weights are copied from the main models to the player:

    # Rebind each player parameter's .data to the matching world-model tensor
    # (the two parameters then share memory; no values are copied)
    for agent_p, p in zip(world_model.encoder.parameters(), player.encoder.parameters()):
        p.data = agent_p.data
    for agent_p, p in zip(world_model.rssm.recurrent_model.parameters(), player.recurrent_model.parameters()):
        p.data = agent_p.data
    for agent_p, p in zip(world_model.rssm.representation_model.parameters(), player.representation_model.parameters()):
        p.data = agent_p.data

If I understand correctly, the player is responsible for interacting with the environment, so its weights must stay the same as those of the world model and the actor as they are updated. However, this looks like a one-time copy rather than continuous synchronization. So my question is: how are the player's weights kept in sync with the main models during training?

Thank you for your time!

michele-milesi commented 1 week ago

Hi @Andy-Mielk, thanks for your question. The player's weights are the same as the world model's weights. agent_p.data is a tensor, and the assignment p.data = agent_p.data makes the player's parameter use that same tensor instead of copying its values, so the two parameters share the same underlying memory. Every time the world model weights are updated, the player therefore reads the updated weights, with no further synchronization step needed.
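A small, self-contained sketch of the mechanism described in the reply, assuming PyTorch and a hypothetical nn.Linear stand-in for the sheeprl modules. It applies the same p.data = agent_p.data loop used in build_agent, takes one SGD step on the main module only, and contrasts the result with a value-only copy_, which is what a true one-time copy would look like. The helpers same_storage and same_values are illustrative and not part of sheeprl.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy module standing in for a world-model component such as the encoder
world_encoder = nn.Linear(4, 2)

# Player A: same pattern as build_agent, rebinding .data to the main model's tensors
shared_player = copy.deepcopy(world_encoder)
for agent_p, p in zip(world_encoder.parameters(), shared_player.parameters()):
    p.data = agent_p.data

# Player B (for contrast): copy the values once into the player's own memory
copied_player = copy.deepcopy(world_encoder)
with torch.no_grad():
    for agent_p, p in zip(world_encoder.parameters(), copied_player.parameters()):
        p.copy_(agent_p)


def same_storage(a: nn.Module, b: nn.Module) -> bool:
    """True if every parameter pair points at the same memory."""
    return all(x.data_ptr() == y.data_ptr() for x, y in zip(a.parameters(), b.parameters()))


def same_values(a: nn.Module, b: nn.Module) -> bool:
    """True if every parameter pair holds identical values."""
    return all(torch.equal(x, y) for x, y in zip(a.parameters(), b.parameters()))


# One optimizer step on the main module only; SGD updates its parameters in place
optimizer = torch.optim.SGD(world_encoder.parameters(), lr=0.1)
loss = world_encoder(torch.randn(8, 4)).pow(2).sum()
loss.backward()
optimizer.step()

print(same_storage(world_encoder, shared_player), same_values(world_encoder, shared_player))  # True True
print(same_storage(world_encoder, copied_player), same_values(world_encoder, copied_player))  # False False
```

The sharing holds as long as the main parameters are updated in place, which is how the built-in PyTorch optimizers step; reassigning a main-model parameter's .data to a brand-new tensor would break the link with the player.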