opendilab / DI-engine

OpenDILab Decision AI Engine. The Most Comprehensive Reinforcement Learning Framework
https://di-engine-docs.readthedocs.io
Apache License 2.0

FQF logit computation #775

Closed dmartinezbaselga closed 8 months ago

dmartinezbaselga commented 9 months ago

Hello!

Thank you for this project, it's really complete and modular, which makes it easy to replicate and modify the code. I have a comment regarding the implementation of FQF. As far as I understand the method, line 772 of the file DI-engine/ding/model/common/head.py should be something like:

q_quantiles_width = q_quantiles[..., 1:] - q_quantiles[..., :-1]
logit = (q.permute(2, 0, 1) * q_quantiles_width).sum(2).permute(1, 0)

instead of: logit = q.mean(1)

Thank you in advance for your time, and sorry if I am wrong with the issue and it's a misunderstanding!
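To make the difference concrete, here is a small self-contained sketch of the two aggregations, assuming `q` has shape `(batch_size, num_quantiles, action_dim)` and `q_quantiles` holds the `num_quantiles + 1` quantile fractions with `tau_0 = 0` and `tau_N = 1` (the tensor names and shapes here are my assumptions based on the snippet above, not DI-engine's exact code):

```python
import torch

batch_size, num_quantiles, action_dim = 4, 32, 6

# Hypothetical tensors mirroring the shapes discussed above
q = torch.randn(batch_size, num_quantiles, action_dim)  # quantile values F^{-1}(tau_hat_i)
taus = torch.rand(batch_size, num_quantiles - 1).sort(dim=1).values
q_quantiles = torch.cat(
    [torch.zeros(batch_size, 1), taus, torch.ones(batch_size, 1)], dim=1
)  # fractions tau_0 = 0, ..., tau_N = 1, shape (batch_size, num_quantiles + 1)

# Current implementation: plain mean, every quantile weighted equally
logit_mean = q.mean(1)  # (batch_size, action_dim)

# Proposed fix: weight each quantile value by its fraction width tau_{i+1} - tau_i
q_quantiles_width = q_quantiles[..., 1:] - q_quantiles[..., :-1]  # (batch_size, num_quantiles)
logit_weighted = (q * q_quantiles_width.unsqueeze(-1)).sum(1)     # (batch_size, action_dim)

# The widths telescope to 1 per sample, so this is a proper weighted average
assert torch.allclose(q_quantiles_width.sum(1), torch.ones(batch_size))
```

The two results coincide only when the learned fractions are uniformly spaced; otherwise the weighted sum is the one that matches the expectation over the quantile distribution.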

PaParaZz1 commented 8 months ago

I think we have already implemented the operation you mentioned at lines 763-768 (link), i.e.:

        q_quantiles_hats = (q_quantiles[:, 1:] + q_quantiles[:, :-1]).detach() / 2.  # (batch_size, num_quantiles)

        # NOTE(rjy): reparameterize q_quantiles_hats
        q_quantile_net = self.quantile_net(q_quantiles_hats)  # [batch_size, num_quantiles, hidden_size(64)]
        # x.view[batch_size, 1, hidden_size(64)]
        q_x = (x.view(batch_size, 1, -1) * q_quantile_net)  # [batch_size, num_quantiles, hidden_size(64)]

Could you please check again? If you have other questions, feel free to continue replying in this issue.
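For context, `quantile_net` here is the usual IQN-style cosine embedding of the quantile fractions. A minimal sketch of what such a module typically looks like (class name, sizes, and activations are illustrative, not DI-engine's exact implementation):

```python
import torch
import torch.nn as nn

class QuantileEmbedding(nn.Module):
    """IQN-style cosine embedding of quantile fractions (illustrative sketch)."""

    def __init__(self, embedding_dim: int = 64, num_cosines: int = 64):
        super().__init__()
        self.num_cosines = num_cosines
        self.fc = nn.Linear(num_cosines, embedding_dim)

    def forward(self, taus: torch.Tensor) -> torch.Tensor:
        # taus: (batch_size, num_quantiles)
        # returns: (batch_size, num_quantiles, embedding_dim)
        i = torch.arange(1, self.num_cosines + 1, device=taus.device).float()
        cosines = torch.cos(taus.unsqueeze(-1) * i * torch.pi)
        return torch.relu(self.fc(cosines))

emb = QuantileEmbedding()
out = emb(torch.rand(4, 32))  # shape (4, 32, 64)
```

This embedding is then multiplied element-wise with the state feature `x`, as in the quoted `q_x` line above.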

dmartinezbaselga commented 8 months ago

Hi,

Thanks for the response. What you are referring to is the computation of $\hat{\tau}_i = \frac{\tau_i + \tau_{i+1}}{2}$ (line 763) and the embedding computation of $\hat{\tau}$ (lines 764-768), which the paper describes in Section 3.4. Those correspond to computing the quantile values. The part I am missing is the $\tau_{i+1} - \tau_i$ factor that measures the width of each quantile fraction when the quantile values are aggregated into Q-values. Instead of a plain mean, it should be a weighted mean.
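For reference, to the best of my reading the aggregation in the FQF paper (with fractions $\tau_0 = 0 < \tau_1 < \dots < \tau_N = 1$) is:

$$Q(s, a) = \sum_{i=0}^{N-1} (\tau_{i+1} - \tau_i)\, F^{-1}_{Z(s,a)}(\hat{\tau}_i), \qquad \hat{\tau}_i = \frac{\tau_i + \tau_{i+1}}{2},$$

i.e. exactly the width-weighted mean of the quantile values, which reduces to `q.mean(1)` only when the fractions are uniformly spaced.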

PaParaZz1 commented 8 months ago

Thanks for your feedback; I have checked this part of the original paper against our implementation.