FLAIROx / JaxMARL

Multi-Agent Reinforcement Learning with JAX
Apache License 2.0
414 stars · 72 forks

Probable issue with implementation of VDN #93

Closed: ElliotXinqiWang closed this 4 months ago

ElliotXinqiWang commented 5 months ago

The TD loss in the implementation of baselines/Qlearning/vdn.py looks confusing. When done is true, the target should be the reward alone, not the value of the next timestep: at a terminal step, target_qs holds the value of the next episode's initial state, which should not leak into the return.

def _td_lambda_target(ret, values):
    reward, done, target_qs = values
    ret = jnp.where(
        done,
        target_qs,
        ret * config['TD_LAMBDA'] * config['GAMMA']
        + reward
        + (1 - config['TD_LAMBDA']) * config['GAMMA'] * (1 - done) * target_qs
    )
    return ret, ret
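To make the claim concrete, here is a minimal, self-contained sketch of the backward TD(λ) recursion using jax.lax.scan, with the terminal branch changed to return the reward alone as the issue proposes. The function name td_lambda_targets and the fixed GAMMA/TD_LAMBDA constants are illustrative assumptions, not the repository's API; only the jnp.where branch differs from the snippet quoted above.

```python
import jax
import jax.numpy as jnp

# Hypothetical constants for illustration (the repo reads these from config)
GAMMA = 0.99
TD_LAMBDA = 0.9

def td_lambda_targets(rewards, dones, target_qs):
    """Compute TD(lambda) targets with a backward scan over the trajectory.

    rewards, dones, target_qs are per-timestep arrays; target_qs[t] is the
    bootstrap value paired with timestep t, and target_qs[-1] seeds the scan.
    """
    def step(ret, xs):
        reward, done, q_boot = xs
        ret = jnp.where(
            done,
            reward,  # proposed fix: at a terminal step, do not bootstrap
            reward + GAMMA * (TD_LAMBDA * ret + (1 - TD_LAMBDA) * q_boot),
        )
        return ret, ret

    _, targets = jax.lax.scan(
        step, target_qs[-1], (rewards, dones, target_qs), reverse=True
    )
    return targets
```

For example, a single terminal step with reward 1.0 yields a target of exactly 1.0 under this variant, whereas the quoted implementation would return target_qs there instead.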
mttga commented 5 months ago

Hi, this follows the pymarl2 implementation: https://github.com/hijkzzz/pymarl2/blob/1254df6f0cb9d1fd2b7c60566a5e23eff8f0528e/src/utils/rl_utils.py#L6

Do you find any differences?