Armandpl / dreamerv3

DreamerV3 + gSDE, using pytorch, on a real robot

refactor and setup training on atari #10

Closed · Armandpl closed this 5 months ago

Armandpl commented 6 months ago

I am confused about how to align the advantages with the actions. Since we predict the reward and compute the lambda return and the advantage from the world model state (h_t, z_t), and since (h_t, z_t) is the result of the action a_{t-1}, I think the advantage at t should be used to push the log prob of the action responsible for it, which is a_{t-1}. Here is how I think the code should look:

    # evaluate the policy at the imagined states (h_t, z_t), dropping the last two steps
    policy = actor(
        sg(hts[:, :-2]),
        sg(zts[:, :-2]),
    )
    # log prob of the action a_t taken at (h_t, z_t)
    logpi = policy.log_prob(sg(ats[:, :-2]).squeeze(-1))
    # pair it with the advantage at t+1, i.e. offset the advantage by one
    actor_loss = -logpi * sg(advantage[:, 1:].squeeze(-1))
    actor_entropy = policy.entropy()
    actor_loss -= ACTOR_ENTROPY * actor_entropy
    actor_loss = actor_loss * sg(traj_weight[:, :-2])
    actor_loss = actor_loss.mean()

But doing this, the training collapses (wandb run): [W&B chart, 3/6/2024 3:16 PM]

However, if I use the advantage at t to push the log prob of a_t:

    # evaluate the policy at the imagined states (h_t, z_t), dropping only the last step
    policy = actor(
        sg(hts[:, :-1]),
        sg(zts[:, :-1]),
    )
    # log prob of a_t, paired with the advantage at the same step t
    logpi = policy.log_prob(sg(ats[:, :-1]).squeeze(-1))
    actor_loss = -logpi * sg(advantage.squeeze(-1))
    actor_entropy = policy.entropy()
    actor_loss -= ACTOR_ENTROPY * actor_entropy
    actor_loss = actor_loss * sg(traj_weight[:, :-1])
    actor_loss = actor_loss.mean()

It now works (wandb run): [W&B chart, 3/6/2024 3:17 PM]
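For reference, here is a minimal sketch of the shapes involved in both variants. It assumes hts/zts/ats span the full imagination horizon H and advantage has H-1 steps, which is what the slicing above implies; the batch size, horizon, and feature sizes below are made up:

    import torch

    B, H = 16, 16                         # hypothetical batch size and imagination horizon
    hts = torch.zeros(B, H, 512)          # deterministic states h_0 .. h_{H-1}
    zts = torch.zeros(B, H, 1024)         # stochastic states z_0 .. z_{H-1}
    ats = torch.zeros(B, H, 1)            # a_t is sampled from the policy at (h_t, z_t)
    advantage = torch.zeros(B, H - 1, 1)  # one advantage per lambda-return step

    # variant 1: advantage at t+1 pushes the log prob of a_t (credit the previous action)
    assert hts[:, :-2].shape[1] == advantage[:, 1:].shape[1]  # both H-2

    # variant 2: advantage at t pushes the log prob of a_t (credit the action at the same step)
    assert hts[:, :-1].shape[1] == advantage.shape[1]         # both H-1

Both slicings are shape-consistent, so the difference is purely in which action each advantage ends up crediting.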

Armandpl commented 6 months ago

Ok, so I trained on Pong overnight, for 150k steps with a training ratio of 1024. It learns some stuff, but it seems more unstable than the official results and reaches a much lower return (-10 instead of 20). Why is that?
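For context, by training ratio I mean replayed steps per environment step, as in the DreamerV3 paper. A minimal sketch of that schedule, with made-up batch shapes rather than the actual loop in this repo:

    TRAIN_RATIO = 1024             # replayed steps per environment step
    BATCH_SIZE, SEQ_LEN = 16, 64   # hypothetical replay batch shape

    # each gradient step replays BATCH_SIZE * SEQ_LEN = 1024 steps, so at a ratio of 1024
    # the loop takes one gradient step roughly every
    # BATCH_SIZE * SEQ_LEN / TRAIN_RATIO = 1 environment step
    train_every = max(1, (BATCH_SIZE * SEQ_LEN) // TRAIN_RATIO)

    def should_train(env_step: int) -> bool:
        return env_step % train_every == 0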

todo: