QinwenLuo closed this issue 1 year ago.
Hi. About the first question: we use the common exp-normalize trick to compute the loss, and we clip sp_term to avoid overly large gradients. About the second question, please see https://github.com/ryanxhr/IVR/issues/1#issuecomment-1486456418. All of these clarifications can also be found in the arXiv version of our paper (https://arxiv.org/pdf/2303.15810.pdf).
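A minimal illustration of the exp-normalize trick mentioned in the reply follows. It is not the repository's code, and the values are arbitrary. Subtracting the batch maximum before exponentiating keeps every term in (0, 1]. Exponentiating directly can overflow float32:

```python
import jax.numpy as jnp

# Arbitrary scaled advantages (q - v) / alpha; one of them is large
sp_term = jnp.array([1.0, 50.0, 100.0], dtype=jnp.float32)

naive = jnp.exp(sp_term)             # exp(100) exceeds the float32 range -> inf
m = jnp.max(sp_term)
normalized = jnp.exp(sp_term - m)    # the largest entry becomes exp(0) = 1

print(jnp.isinf(naive).any())        # True
print(jnp.isfinite(normalized).all())  # True
```

In the EQL loss the same subtraction appears as `exp(sp_term - max_sp_term)`. The `v / alpha` term is multiplied by `exp(-max_sp_term)`, so the whole loss is rescaled by one positive constant.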
Thank you for your answer!
```python
elif alg == 'EQL':
    sp_term = (q - v) / alpha                       # scaled advantage
    sp_term = jnp.minimum(sp_term, 5.0)             # clip large exponents
    max_sp_term = jnp.max(sp_term, axis=0)          # batch maximum for exp-normalize
    max_sp_term = jnp.where(max_sp_term < -1.0, -1.0, max_sp_term)  # floor the maximum at -1.0
    max_sp_term = jax.lax.stop_gradient(max_sp_term)  # treat the normalizer as a constant
    value_loss = (jnp.exp(sp_term - max_sp_term) + jnp.exp(-max_sp_term) * v / alpha).mean()
```
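The sketch below checks what this normalization does to the gradient. It wraps the quoted EQL branch in a hypothetical function `eql_value_loss`; the function and argument names are illustrative and not from the repository. It assumes `q` and `v` are 1-D batches. It compares that function with the unnormalized objective `mean(exp((q - v) / alpha) + v / alpha)`, which is what the quoted loss reduces to once `exp(-max_sp_term)` is factored out and no clipping is active. Because `max_sp_term` passes through `stop_gradient`, it acts as a constant. The gradient should therefore be the unnormalized gradient scaled by `exp(-max_sp_term) > 0`:

```python
import jax
import jax.numpy as jnp

def eql_value_loss(v, q, alpha):
    # Hypothetical wrapper around the quoted EQL branch, differentiated w.r.t. v
    sp_term = (q - v) / alpha
    sp_term = jnp.minimum(sp_term, 5.0)
    max_sp_term = jnp.max(sp_term, axis=0)
    max_sp_term = jnp.where(max_sp_term < -1.0, -1.0, max_sp_term)
    max_sp_term = jax.lax.stop_gradient(max_sp_term)
    return (jnp.exp(sp_term - max_sp_term) + jnp.exp(-max_sp_term) * v / alpha).mean()

def unnormalized_loss(v, q, alpha):
    # Same objective without the exp(-m) rescaling and without clipping
    return (jnp.exp((q - v) / alpha) + v / alpha).mean()

alpha = 0.5
q = jnp.array([1.0, 1.5, 2.5])
v = jnp.array([0.5, 0.5, 0.5])  # sp_term = [1, 2, 4], so no clipping is active

g_norm = jax.grad(eql_value_loss)(v, q, alpha)
g_raw = jax.grad(unnormalized_loss)(v, q, alpha)
m = jnp.max((q - v) / alpha)
# Same direction, scaled by the positive constant exp(-m)
print(jnp.allclose(g_norm, jnp.exp(-m) * g_raw, rtol=1e-4))  # True
```

Under these assumptions, the rescaling leaves the points where the gradient vanishes unchanged and only changes the effective step size. The clip `jnp.minimum(sp_term, 5.0)` does change the objective. Samples with `sp_term > 5` get zero gradient through the exponential term, so only the `exp(-max_sp_term) * v / alpha` term contributes to their gradient. The floor of -1.0 on `max_sp_term` caps the rescaling factor `exp(-max_sp_term)` at e.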
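In this EQL value-function update, why is sp_term clipped to at most 5.0, and why is the loss multiplied by a factor of exp{-max_sp_term}? Is it to keep the exponential term from growing so large that the loss blows up?

```python
if alg == 'SQL':
    weight = q - v
    weight = jnp.maximum(weight, 0)             # SQL: advantage clipped at zero
elif alg == 'EQL':
    weight = jnp.exp(10 * (q - v) / alpha)      # EQL: exponential weight with an extra factor of 10
```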
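For the SQL policy update, shouldn't the weight be (1 + (q - v) / (2 * alpha))? Why does the code use q - v directly? And for the EQL policy update, why is (q - v) / alpha multiplied by 10? Both updates differ from the formulas derived in the paper, which I find confusing. I would appreciate an explanation when you have time.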
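The sketch below does not explain why the code departs from the paper; the maintainer's reply points to issue #1 for that. It only makes two algebraic facts concrete. First, `exp(10 * (q - v) / alpha)` equals `exp((q - v) / (alpha / 10))`, so the factor 10 is the same as using a temperature of `alpha / 10` in the policy weight. Second, `1 + (q - v) / (2 * alpha) = (q - v + 2 * alpha) / (2 * alpha)`. This is a positive multiple of `q - v + 2 * alpha` rather than of `q - v`, so the two SQL weightings differ by a shift of the advantage, not just by a scale factor. The function names are hypothetical, `q` and `v` are assumed to be 1-D arrays, and the SQL form is taken verbatim from the question without clipping:

```python
import jax.numpy as jnp

def sql_weight_code(q, v):
    # SQL weight as written in the quoted code: max(q - v, 0)
    return jnp.maximum(q - v, 0.0)

def sql_weight_question_form(q, v, alpha):
    # SQL weight as stated in the question: 1 + (q - v) / (2 * alpha), unclipped
    return 1.0 + (q - v) / (2.0 * alpha)

def eql_weight_code(q, v, alpha):
    # EQL weight as written in the quoted code
    return jnp.exp(10 * (q - v) / alpha)

def exp_weight(q, v, temperature):
    # Generic exponential advantage weight exp((q - v) / temperature)
    return jnp.exp((q - v) / temperature)

q = jnp.array([0.2, 1.0, 2.0])
v = jnp.array([1.0, 1.0, 1.0])
alpha = 2.0

# The factor 10 is equivalent to a temperature of alpha / 10
print(jnp.allclose(eql_weight_code(q, v, alpha), exp_weight(q, v, alpha / 10)))  # True

# The question's form equals (q - v + 2 * alpha) / (2 * alpha): shifted, not just rescaled
print(sql_weight_code(q, v))
print(sql_weight_question_form(q, v, alpha))
print(jnp.allclose(sql_weight_question_form(q, v, alpha), (q - v + 2 * alpha) / (2 * alpha)))  # True
```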