ZhengYinan-AIR / FISOR

[ICLR 2024] The official implementation of "Safe Offline Reinforcement Learning with Feasibility-Guided Diffusion Model"
https://zhengyinan-air.github.io/FISOR/

Why does the time variable t need to be input into the FourierFeatures class? #1

ZhaoRunyi closed this issue 5 months ago

ZhaoRunyi commented 6 months ago

Hi! In 'https://github.com/ZhengYinan-AIR/FISOR/blob/master/jaxrl5/networks/diffusion.py', the time variable t is passed through FourierFeatures and then an MLP to produce a time embedding. I wonder what the purpose of this is. Fourier features seem to be a kind of high-frequency feature; if so, why do such high-frequency features need to be extracted and then embedded?

Facebear-ljx commented 6 months ago

Hi! Thanks for your interest in our work. This is a good question, and we did not investigate it in depth. Actually, this is a common choice in many popular diffusion policies such as IDQL and Decision Diffuser, and also in many diffusion models for image generation. So, we directly followed this implementation.

From a quick search, I found that Fourier features can improve generalization to unseen times, help the neural network extract high-frequency features from the scalar time variable, and have other good properties. We would also be very happy to hear about any further findings of yours on the advantages of Fourier features.

ZhaoRunyi commented 6 months ago

Thanks for your reply! I may look into the relation between Fourier features and generalization later, to see if there is any literature that explains it further. I will share my findings as soon as I finish the research.

Best regards!

zhyang2226 commented 6 months ago

Actually, this timestep embedding method comes from the Transformer (Attention Is All You Need) and has been widely used in diffusion models and Transformer-like network architectures. If the scalar timestep is fed to the network directly, it is a single dimension among many high-dimensional inputs, so it is hard for the final output to depend heavily on it, which is not what we want. Encoding the scalar timestep into a vector with a dimension similar to that of the other inputs solves this problem.
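
To make the idea concrete, here is a minimal sketch of a Transformer-style sinusoidal timestep embedding. It is written in plain Python so it runs without JAX, and it is not the code from diffusion.py: the function name `sinusoidal_embedding` and the parameters `dim` and `max_period` are assumptions chosen for illustration.

```python
import math


def sinusoidal_embedding(t, dim=16, max_period=10000.0):
    """Map a scalar timestep t to a dim-dimensional vector of cosines and sines.

    Hypothetical helper for illustration only; not the repository's FourierFeatures.
    """
    if dim % 2 != 0 or dim < 4:
        raise ValueError("dim must be an even number >= 4")
    half = dim // 2
    # Geometrically spaced frequencies, from 1 down to 1 / max_period.
    freqs = [math.exp(-math.log(max_period) * i / (half - 1)) for i in range(half)]
    angles = [t * f for f in freqs]
    # Concatenate the cosine half and the sine half into one vector.
    return [math.cos(a) for a in angles] + [math.sin(a) for a in angles]


# One scalar goes in, a 16-dimensional vector comes out.
emb = sinusoidal_embedding(3.0, dim=16)
print(len(emb))
```

In a diffusion policy, a vector like `emb` would typically be passed through a small MLP and concatenated with the state and noisy action before the main network, so the timestep occupies a comparable share of the input rather than a single coordinate.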