buoyancy99 / diffusion-forcing

code for "Diffusion Forcing: Next-token Prediction Meets Full-Sequence Diffusion"

Questions on the Configuration of Maze Planning #18

Closed: hw-du closed this issue 2 months ago.

hw-du commented 2 months ago

Hi, thanks for the great work!

I am a little confused about the configuration of the mean/std of the observation/action/reward. I see that they are used to normalize the observation/action/reward, but how should these mean/std values be set? Are they crucial to the success of planning with diffusion models?

Thanks!

Configurations from maze2d_large.yaml:

```yaml
observation_mean: [3.7296331e+00, 5.3047247e+00, 4.7289828e-05, 2.9168357e-05]
observation_std: [1.8070312, 2.5687592, 2.4368455, 2.6493697]
action_mean: [0.00100675, 0.00078245]
action_std: [0.72493, 0.7394606]
reward_mean: 0.0089225
reward_std: 0.09403666
```
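To make the role of these values concrete, here is a minimal sketch of how per-dimension mean/std values are typically applied, assuming standard z-score normalization (subtract the mean, divide by the std) before data reaches the model and the inverse afterwards. This thread does not show the repository's actual normalization code, so `normalize` and `unnormalize` are hypothetical helpers, not functions from diffusion-forcing.

```python
import numpy as np

# Values copied from maze2d_large.yaml in the question above.
OBS_MEAN = np.array([3.7296331e00, 5.3047247e00, 4.7289828e-05, 2.9168357e-05])
OBS_STD = np.array([1.8070312, 2.5687592, 2.4368455, 2.6493697])


def normalize(x, mean, std):
    """Hypothetical helper: z-score each dimension (subtract mean, divide by std)."""
    return (x - mean) / std


def unnormalize(z, mean, std):
    """Hypothetical helper: invert normalize() to get back environment units."""
    return z * std + mean


# Two example observations (x, y, vx, vy) in environment units.
obs = np.array([[1.0, 1.0, 0.0, 0.0], [7.0, 9.0, 0.5, -0.5]])
z = normalize(obs, OBS_MEAN, OBS_STD)
print(z)
# Round-tripping recovers the original observations (up to float error).
print(np.allclose(unnormalize(z, OBS_MEAN, OBS_STD), obs))
```

Under this assumption, a std that is too small relative to the true spread of the data produces large normalized values, which is one plausible reason the choice matters for planning.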

buoyancy99 commented 2 months ago

Yes, they are very important for the maze planning experiment. To set them, either modify the yaml files directly or follow my example commands from README.md to override the values from the command line.

```
python -m main experiment=exp_planning algorithm=df_planning dataset=maze2d_medium dataset.action_mean=[] dataset.action_std=[] dataset.observation_mean=[3.5092521,3.4765592] dataset.observation_std=[1.3371079,1.52102] load=outputs/maze2d_medium_x.ckpt experiment.tasks=[validation] algorithm.guidance_scale=3 +name=maze2d_medium_x_sampling
```

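If you compute the statistics in Python, a small helper can turn an array into an override string in the same `key=[v1,v2,...]` form as the command above (the syntax, including the `+name=` prefix, looks like Hydra's override grammar). `hydra_list_override` is a hypothetical convenience function, not part of the repository; an empty input yields `key=[]`, matching the `dataset.action_mean=[]` override in the example command.

```python
import numpy as np


def hydra_list_override(key, values):
    """Hypothetical helper: format values as a `key=[v1,v2,...]` command-line override.

    repr(float(v)) gives the shortest string that round-trips the float64 value;
    float32 inputs (common in D4RL datasets) will print with more digits.
    """
    vals = ",".join(repr(float(v)) for v in np.atleast_1d(np.asarray(values, dtype=np.float64)))
    return f"{key}=[{vals}]"


print(hydra_list_override("dataset.observation_mean", [3.5092521, 3.4765592]))
# prints "dataset.observation_mean=[3.5092521,3.4765592]"
print(hydra_list_override("dataset.action_mean", []))
# prints "dataset.action_mean=[]"
```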
hw-du commented 2 months ago

Thanks for the answer! But what I really want to know is how to compute the exact values of these mean/std parameters.

buoyancy99 commented 2 months ago

Basically, you compute them from the dataset and then scale them by some constant:

```python
import gym
import d4rl  # noqa: F401 (importing d4rl registers the maze2d environments with gym)

env = gym.make("your env name here")  # e.g. "maze2d-large-v1"
dataset = env.get_dataset()
print("episode len", env._max_episode_steps)

o = dataset["observations"]
x = o[:, 0]  # x position
y = o[:, 1]  # y position
r = dataset["rewards"]
a = dataset["actions"]
# Statistics per dimension (axis=0), matching the per-dimension lists in the yaml files.
print("o_mean, o_std", o.mean(axis=0), o.std(axis=0))
print("a_min, a_max, a_mean, a_std", a.min(), a.max(), a.mean(axis=0), a.std(axis=0))
print("r_min, r_max, r_mean, r_std", r.min(), r.max(), r.mean(), r.std())
```

This will give you the numbers in the yaml files. However, this all assumes that your data follows a normal distribution, which is not the case most of the time, so as a tuning step you multiply the numbers by some constant, such as 2x or 4x.
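The same computation can be wrapped in a function that returns values keyed like the yaml fields and applies the tuning constant. This is a sketch under two assumptions: the constant multiplies only the std values (the comment above does not say exactly which numbers are scaled), and `compute_norm_stats` is a hypothetical name. A synthetic dataset stands in for `env.get_dataset()` so the example runs without D4RL; its uniform observations and sparse 0/1 rewards are deliberately non-Gaussian.

```python
import numpy as np


def compute_norm_stats(dataset, std_scale=1.0):
    """Hypothetical helper: per-dimension stats from a D4RL-style dict of arrays.

    std_scale is the tuning constant (e.g. 2 or 4); here it is ASSUMED to scale
    only the std values, leaving the means unchanged.
    """
    o = np.asarray(dataset["observations"], dtype=np.float64)
    a = np.asarray(dataset["actions"], dtype=np.float64)
    r = np.asarray(dataset["rewards"], dtype=np.float64)
    return {
        "observation_mean": o.mean(axis=0).tolist(),
        "observation_std": (o.std(axis=0) * std_scale).tolist(),
        "action_mean": a.mean(axis=0).tolist(),
        "action_std": (a.std(axis=0) * std_scale).tolist(),
        "reward_mean": float(r.mean()),
        "reward_std": float(r.std() * std_scale),
    }


# Synthetic stand-in for env.get_dataset(): 1000 steps, 4-D obs, 2-D actions.
rng = np.random.default_rng(0)
fake = {
    "observations": rng.uniform(0.0, 7.0, size=(1000, 4)),
    "actions": rng.uniform(-1.0, 1.0, size=(1000, 2)),
    "rewards": (rng.random(1000) < 0.01).astype(np.float32),
}
base = compute_norm_stats(fake)
wide = compute_norm_stats(fake, std_scale=4.0)

# Dividing by a 4x std shrinks every normalized value by a factor of 4.
z1 = (fake["observations"] - base["observation_mean"]) / base["observation_std"]
z4 = (fake["observations"] - wide["observation_mean"]) / wide["observation_std"]
print("max |z| with 1x std:", np.abs(z1).max())
print("max |z| with 4x std:", np.abs(z4).max())
```

Scaling the std is equivalent to shrinking the normalized data toward zero, which is one way to keep outliers from a non-Gaussian distribution within a smaller range; which constant works best is, as noted above, a matter of tuning.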

hw-du commented 2 months ago

Thanks! This is really helpful.