seohongpark / HIQL

HIQL: Offline Goal-Conditioned RL with Latent States as Actions (NeurIPS 2023)
MIT License

Several questions about function compute_value_loss(). #6

Closed: Looomo closed this issue 2 weeks ago

Looomo commented 2 weeks ago

Dear Authors, nice work on HIQL! I have recently been running your code, but I have had some difficulty understanding parts of it. Could you please explain the rationale behind the design choices below? That would be a great help!

  1. In Equation 4 of your paper, the objective for V(s, g) is the expectile loss $L_2^\tau\big(r + V_{\text{target}}(s', g) - V(s, g)\big)$, where $L_2^\tau(u) = |\tau - \mathbb{1}(u < 0)|\,u^2$. However, this seems to differ from the implementation of compute_value_loss, where the argument of the indicator is adv:

https://github.com/seohongpark/HIQL/blob/b3e8366ccaec99113778bc360b19894e7a63317c/src/agents/hiql.py#L103

https://github.com/seohongpark/HIQL/blob/b3e8366ccaec99113778bc360b19894e7a63317c/src/agents/hiql.py#L91-L97

Since $\text{adv} = r + V_{\text{target}}(s', g) - V_{\text{target}}(s, g)$ uses the target network rather than the online V(s, g), this seems to differ from Eq. 4.
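The discrepancy in question 1 comes down to which quantity decides the asymmetric weight. A minimal NumPy sketch of both readings follows; the function name `expectile_loss`, the expectile 0.7, and the toy arrays are illustrative assumptions rather than values from the repository, and the discount factor is omitted to match the formulas above.

```python
import numpy as np


def expectile_loss(adv, diff, expectile=0.7):
    """Asymmetric squared loss (hypothetical helper).

    `adv` only decides which weight applies (expectile if adv >= 0,
    else 1 - expectile); `diff` is the residual that gets squared.
    """
    weight = np.where(adv >= 0, expectile, 1.0 - expectile)
    return weight * diff ** 2


# Toy batch (made-up numbers): TD target r + V_target(s', g),
# online estimate V(s, g), and target-network estimate V_target(s, g).
td_target = np.array([-1.0, -3.0, -2.0])
v_online = np.array([-2.0, -2.5, -1.0])
v_target = np.array([-1.5, -3.5, -2.5])

# Eq. 4 as written: the same residual sets the weight and is squared.
residual = td_target - v_online
loss_eq4 = expectile_loss(residual, residual).mean()

# Decoupled reading: the weight comes from the target-network advantage,
# while the squared term still uses the online residual.
adv = td_target - v_target
loss_decoupled = expectile_loss(adv, residual).mean()
```

On this toy batch the two losses differ because the second and third samples have a negative residual (weight 0.3 under Eq. 4) but a positive target-network advantage (weight 0.7 under the decoupled rule).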

  2. I also have some questions about how the values of V(s, g) are combined. Specifically, on line 96 (https://github.com/seohongpark/HIQL/blob/b3e8366ccaec99113778bc360b19894e7a63317c/src/agents/hiql.py#L96), V(s, g) seems to be computed as the average of V(s) and V(g). But on line 92 (https://github.com/seohongpark/HIQL/blob/b3e8366ccaec99113778bc360b19894e7a63317c/src/agents/hiql.py#L92), V(s', g) is computed as the minimum of V(s') and V(g). Moreover, on lines 99-101, the value losses are computed separately for V(s) and V(g).
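The combination pattern described in question 2 can be sketched in NumPy as follows. This is a hedged reconstruction of the pattern, not the repository code: the function name `value_loss_sketch`, its arguments, the default discount and expectile, and the `mask` convention (0 at terminal transitions, 1 otherwise) are assumptions. The pairs `v1`/`v2` are treated as two heads of a value ensemble, each evaluated at (s, g), which is how the reply in this thread describes them. The discount factor, left implicit in the formulas above, is included explicitly.

```python
import numpy as np


def expectile_loss(adv, diff, expectile=0.7):
    # Weight chosen by the sign of `adv`; the squared term uses `diff`.
    weight = np.where(adv >= 0, expectile, 1.0 - expectile)
    return weight * diff ** 2


def value_loss_sketch(r, mask, next_v1_t, next_v2_t, v1_t, v2_t, v1, v2,
                      discount=0.99, expectile=0.7):
    """Hypothetical two-head value loss.

    Arrays ending in `_t` come from the target network; `v1`/`v2` come from
    the online network. Index 1/2 selects one ensemble head.
    """
    # Pessimistic bootstrap: elementwise min over heads of V_target(s', g).
    next_v = np.minimum(next_v1_t, next_v2_t)
    q = r + discount * mask * next_v

    # Current-state target value: mean over heads of V_target(s, g).
    v_t = (v1_t + v2_t) / 2
    adv = q - v_t  # shared advantage that sets the expectile weight

    # Each online head regresses toward a target built from its own next value.
    q1 = r + discount * mask * next_v1_t
    q2 = r + discount * mask * next_v2_t
    loss1 = expectile_loss(adv, q1 - v1, expectile).mean()
    loss2 = expectile_loss(adv, q2 - v2, expectile).mean()
    return loss1 + loss2
```

In this sketch both heads share one advantage for choosing the weight, but each head's squared error uses its own bootstrapped target, which matches the "separate losses" the question points to.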

Could you elaborate on these design choices? Thanks~!

seohongpark commented 2 weeks ago

Hi Looomo,

Thanks for the questions. These are small implementation details that we found improve the performance of GC-IQL. Regarding the first point, we found that decoupling the advantage from the loss (similar to how Double DQN decouples action selection, the argmax, from evaluation, the max) generally improves performance slightly. Regarding the second point, either (v1+v2)/2 or min(v1, v2) can generally be used to compute a scalar value from two ensemble value functions. We tested several variants and found that the current one performs best in general. That said, as far as I remember, these minor details didn't affect performance much (though sometimes one variant was more stable or slightly better than the others).
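The two reductions mentioned in the reply can be written as one small helper, which is one way to make a variant comparison configurable. The helper name `aggregate_heads` and its `mode` strings are hypothetical. Because min(v1, v2) never exceeds (v1 + v2)/2, the min reduction is the more pessimistic estimate, which is the usual motivation for applying it to bootstrapped targets (as in clipped double Q-learning); whether it helps in a given setting is an empirical question, as the reply notes.

```python
import numpy as np


def aggregate_heads(v1, v2, mode="mean"):
    """Reduce two ensemble value estimates to one array (hypothetical helper)."""
    if mode == "mean":
        return (v1 + v2) / 2
    if mode == "min":
        return np.minimum(v1, v2)
    raise ValueError(f"unknown mode: {mode!r}")


# Example: the elementwise min is never larger than the elementwise mean.
v1 = np.array([-2.0, -4.0, -1.0])
v2 = np.array([-3.0, -1.0, -1.0])
pessimistic = aggregate_heads(v1, v2, "min")
averaged = aggregate_heads(v1, v2, "mean")
assert np.all(pessimistic <= averaged)
```

The two heads only produce different reductions where they disagree; in the third entry above, where v1 and v2 are equal, both modes return the same value.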