zhaoyl18 / SEIKO

SEIKO is a novel reinforcement learning method to efficiently fine-tune diffusion models in an online setting. Our methods outperform all baselines (PPO, classifier-based guidance, direct reward backpropagation) for fine-tuning Stable Diffusion.
https://arxiv.org/abs/2402.16359

Question about the implementation of bootstrap reward #3

Closed: Guo-Stone closed this issue 3 weeks ago

Guo-Stone commented 1 month ago

Dear authors, I would like to ask a question about the code implementation of the bootstrap reward. In my opinion, the bootstrap method involves training several models on different resampled datasets, and the reward $r$ and its uncertainty $g$ should then be the mean and standard deviation of the outputs of all the bootstrapped models. In your code below, why do you only choose the best output as the reward $r$ and not take the uncertainty $g$ into consideration?

optimistic_rewards, _ = torch.max(stacked_outputs, dim=1, keepdim=True)
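
For reference, a minimal sketch of the two alternatives being contrasted here, assuming stacked_outputs is a [batch, num_heads] tensor of per-head reward predictions (the bonus coefficient is a placeholder, not a value from the repo):

```python
import torch

# Illustrative only: per-head reward predictions, shape [batch, num_heads].
stacked_outputs = torch.randn(8, 4)

# What the question suggests: mean plus a scaled deviation bonus.
mean_r = stacked_outputs.mean(dim=1, keepdim=True)
std_r = stacked_outputs.std(dim=1, keepdim=True)
ucb_rewards = mean_r + 1.0 * std_r  # the coefficient 1.0 is a placeholder

# What the snippet above does: take the per-sample max over heads.
optimistic_rewards, _ = torch.max(stacked_outputs, dim=1, keepdim=True)
```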

zhaoyl18 commented 3 weeks ago

Hi, thanks for your interest in our work. Indeed, in statistical theory, bootstrapping refers to resampling data. In RL, however, bootstrapping is used as a way to encourage exploration; see https://rail.eecs.berkeley.edu/deeprlcourse/deeprlcourse/static/slides/lec-13.pdf

When implementing bootstrapping in RL, there are several common considerations.

  1. Training N large neural networks is expensive. To save computation at training and inference time, a useful practice is to construct a shared backbone with multiple heads, so that most of the architecture is shared across the bootstrapped models (see page 35 of the linked slides). In our implementation, we found that 4 predictor heads work well.
  2. In practice, researchers have found bootstrapping to be very effective for RL exploration. Even training two different neural networks and taking the min/max of their outputs already works well.
  3. Regarding inference, there are several options. The most principled one is closely related to Thompson sampling (see page 36 of the slides), but picking the prior distribution is highly manual; taking an average of all outputs is, to some extent, related to this. In RL practice, taking the min/max or softmin/softmax of the outputs are all simple yet effective approaches.
  4. For our work, it is possible to use (1) softmax, (2) max, or (3) average. Experimentally, we found that (1) requires choosing a temperature parameter, which is hard to tune and can be avoided by using (2); our implementation is based on (2). A rough sketch of the shared-backbone setup and these reductions follows this list.
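
As an illustration of the points above (this is not the repo's actual code; the layer sizes, number of heads, and temperature are placeholder values), a shared-backbone reward model with multiple heads and the three candidate reductions might look like the following sketch:

```python
import torch
import torch.nn as nn


class BootstrappedReward(nn.Module):
    """Sketch: shared backbone with multiple bootstrapped reward heads.

    Sizes are illustrative; SEIKO's actual reward model operates on
    diffusion-model features rather than generic feature vectors.
    """

    def __init__(self, feature_dim=512, hidden_dim=256, num_heads=4):
        super().__init__()
        # Most parameters live in the shared trunk (point 1).
        self.backbone = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
        )
        # Lightweight per-head predictors, each trained on resampled data.
        self.heads = nn.ModuleList(
            [nn.Linear(hidden_dim, 1) for _ in range(num_heads)]
        )

    def forward(self, features):
        h = self.backbone(features)
        # Shape [batch, num_heads]: one scalar prediction per head.
        return torch.cat([head(h) for head in self.heads], dim=1)


def optimistic_reward(stacked_outputs, mode="max", temperature=1.0):
    """Reduce per-head predictions to a single optimistic reward (points 3-4)."""
    if mode == "average":
        return stacked_outputs.mean(dim=1, keepdim=True)
    if mode == "softmax":
        # Softmax-weighted average; requires tuning the temperature.
        weights = torch.softmax(stacked_outputs / temperature, dim=1)
        return (weights * stacked_outputs).sum(dim=1, keepdim=True)
    # Default: per-sample max over heads, as in the snippet quoted above.
    values, _ = torch.max(stacked_outputs, dim=1, keepdim=True)
    return values


# Usage sketch:
model = BootstrappedReward()
features = torch.randn(8, 512)
rewards = optimistic_reward(model(features), mode="max")
```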