jasonkyuyim / se3_diffusion

Implementation for SE(3) diffusion model with application to protein backbone generation
https://arxiv.org/abs/2302.02277
MIT License
333 stars 54 forks

self condition question #27

Closed Wangchentong closed 7 months ago

Wangchentong commented 1 year ago

Hi, I noticed that the self-conditioning training strategy in FrameDiff is not the same as in RFdiffusion:

Self-conditioning training in RFdiffusion (supplementary material, Algorithm 3):
▷ Train step with self-conditioning x(t+1)
    x(t+1) = ForwardNoise(x(0), t + 1)  ▷ Sample (x(t+1), x(t)) ∼ q(x(t:t+1) | x(0))
    x(t) = ReverseStep(x(t+1), x(0))
    Compute self-conditioning input:
    x̂(0)_prev = RFdiffusion(x(t+1), 0)
    x̂(0)_prev = StopGradient(x̂(0)_prev)
    x̂(0) = RFdiffusion(x(t), x̂(0)_prev)
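
The control flow of the algorithm above can be sketched in plain Python. This is only a toy illustration: `forward_noise`, `reverse_step`, `toy_model`, and `train_step_inputs` are hypothetical names, the data are lists of floats rather than protein frames, and the stop-gradient is a plain copy because no autograd is involved.

```python
import random

def forward_noise(x0, t, rng):
    """Toy forward process: blend x0 with Gaussian noise, noise level t in [0, 1]."""
    return [(1.0 - t) * v + t * rng.gauss(0.0, 1.0) for v in x0]

def reverse_step(x_next, x0, t_next, t):
    """Toy deterministic reverse step from noise level t_next down to t, given x0."""
    if t_next == 0.0:
        return list(x0)
    w = t / t_next
    return [w * xn + (1.0 - w) * v for xn, v in zip(x_next, x0)]

def toy_model(x_t, x0_prev):
    """Hypothetical denoiser: averages the noisy input and the self-conditioning input."""
    return [0.5 * a + 0.5 * b for a, b in zip(x_t, x0_prev)]

def train_step_inputs(x0, t, dt, rng):
    t_next = min(1.0, t + dt)
    x_next = forward_noise(x0, t_next, rng)      # x(t+1)
    x_t = reverse_step(x_next, x0, t_next, t)    # x(t)
    zeros = [0.0] * len(x0)
    x0_prev = toy_model(x_next, zeros)           # first pass on the noisier sample
    x0_prev = list(x0_prev)                      # stands in for StopGradient
    x0_hat = toy_model(x_t, x0_prev)             # second pass, self-conditioned
    return x_next, x_t, x0_prev, x0_hat
```

The point of the sketch is the ordering: the self-conditioning input comes from the noisier sample x(t+1), while the prediction that would receive the loss is made on x(t).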

So I rewrote the FrameDiff code to follow the same logic as RFdiffusion:

# experiment/train_se3_diffusion.py
if self._model_conf.embed.self_condition.version and self._generator.random() > 0.5:
    # features shared by both steps
    prev_batch = {k: v for k, v in batch.items() if "self_condition_" not in k}
    # diffused features for the previous (noisier) step, with the prefix stripped
    prev_batch.update({k[len("self_condition_"):]: v for k, v in batch.items() if "self_condition_" in k})
    with torch.no_grad():
        self_condition = self._model(prev_batch)
        if self._generator.random() < 0.5:
            self_condition = {k: v for k, v in self_condition.items() if k not in ['edge_embed', 'node_embed']}
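
To make the key handling in the snippet above concrete, here is a small standalone sketch on a toy dictionary. The helper name `split_self_condition` and the string placeholders are hypothetical, and it uses `startswith` instead of the substring test in the snippet, which is equivalent as long as the prefix only ever appears at the start of a key.

```python
PREFIX = "self_condition_"

def split_self_condition(batch):
    """Build the previous-step batch: shared keys plus prefixed keys with the prefix stripped."""
    prev = {k: v for k, v in batch.items() if not k.startswith(PREFIX)}
    prev.update({k[len(PREFIX):]: v for k, v in batch.items() if k.startswith(PREFIX)})
    return prev

# Toy batch: the training features sit at t - dt, the prefixed ones at t.
batch = {
    "aatype": "A",
    "rigids_t": "x(t-dt)",
    "t": 0.49,
    "self_condition_rigids_t": "x(t)",
    "self_condition_t": 0.5,
}
prev_batch = split_self_condition(batch)
```

The prefixed entries overwrite the unprefixed training features of the same name, so the first forward pass sees the noisier sample and its time.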

And I fixed delta_t_range = [1] so that it conditions on only one step:

# pdb_dataloader.py Class PdbDataSet  Function __getitem__(self, idx):
if self.is_training:
          # prev step feature for self-condition 
          t = np.random.uniform(self._data_conf.min_t, 1.0)
          diff_feats_t_prev = self._diffuser.forward_marginal(
              rigids_0=gt_bb_rigid,
              t=t,
              diffuse_mask=None
          )
          diff_feats_t_prev['t'] = t
          # training step feature
          dt = (1.0-self._data_conf.min_t)/self._data_conf.num_t
          # candidate step sizes, clipped so that t-dt never falls below min_t
          dt_choices = [min(t-self._data_conf.min_t,dt*i) for i in self._data_conf.delta_t_range]
          dt = dt_choices[rng.integers(len(dt_choices))]
          rigids_t = self._diffuser.reverse(
              rigid_t=rigid_utils.Rigid.from_tensor_7(diff_feats_t_prev['rigids_t']),
              rot_score=du.move_to_np(diff_feats_t_prev["rot_score"]),
              trans_score=du.move_to_np(diff_feats_t_prev["trans_score"]),
              diffuse_mask=None,
              t=t,
              dt=dt,
              center=False,
              noise_scale=1.0,
          )
          rot_score = self._diffuser.calc_rot_score(
              rigids_t.get_rots()[None,...],
              gt_bb_rigid.get_rots()[None,...],
              t = torch.Tensor([t-dt])
          )[0]
          trans_score = self._diffuser.calc_trans_score(
              rigids_t.get_trans()[None,...],
              gt_bb_rigid.get_trans()[None,...],
              t = torch.Tensor([t-dt]),
          )[0]
          rot_score_scaling,trans_score_scaling = self._diffuser.score_scaling(t-dt)
          diff_feats_t = {
              'rigids_t': rigids_t.to_tensor_7(),
              'trans_score': trans_score,
              'rot_score': rot_score,
              'trans_score_scaling': trans_score_scaling,
              'rot_score_scaling': rot_score_scaling,
              'fixed_mask': fixed_mask,
              **{"self_condition_"+k:v for k,v in diff_feats_t_prev.items()}
          }
          diff_feats_t['t'] = t-dt
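
The step-size selection in the snippet above can be checked in isolation. The sketch below is an assumption-laden restatement: `pick_dt` is a hypothetical helper, the index is passed in explicitly instead of being drawn from `rng`, and the values of `min_t` and `num_t` in the usage lines are illustrative only.

```python
def pick_dt(t, min_t, num_t, delta_t_range, index):
    """Return the reverse-step size for time t, clipped so that t - dt >= min_t."""
    base_dt = (1.0 - min_t) / num_t
    candidates = [min(t - min_t, base_dt * i) for i in delta_t_range]
    return candidates[index]

# Illustrative values (assumed, not taken from the repository config).
dt_mid = pick_dt(t=0.5, min_t=0.01, num_t=100, delta_t_range=[1], index=0)
dt_low = pick_dt(t=0.012, min_t=0.01, num_t=100, delta_t_range=[1], index=0)
```

With `delta_t_range = [1]` there is a single candidate: one discretization step, except near `min_t` where the step shrinks so the training time lands on `min_t` rather than below it.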

But this self-conditioning training strategy gives noticeably poorer performance on secondary structure, even though the training loss is almost the same, which confuses me: [image] [image] It seems that this strategy leads to a greater inconsistency between training and inference, whereas intuitively it should be more consistent between training and inference, and RFdiffusion shows that it can work.

jasonkyuyim commented 1 year ago

Hi, are you sure there's no bug in your inference code? I wouldn't expect a noticeable difference between the two strategies. FrameDiff follows the original self-conditioning implementation, while RFdiffusion trains self-conditioning with one step forward (t+1), which I guess just works better. I honestly have no scientific explanation of why one would work better than the other. Keep in mind these are two very different training set-ups, so what works for one will not necessarily work for the other.
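
The difference between the two strategies comes down to which sample feeds the first, gradient-free forward pass. A minimal sketch, assuming a model with signature `model(x, self_cond)` and a self-conditioning probability `p`; the function name, the variant labels, and the default `p=0.5` are all illustrative assumptions:

```python
import random

def self_condition_input(model, x_t, x_next, variant, rng, p=0.5):
    """Return the self-conditioning input for a training step.

    variant == "same_step": first pass on x(t) itself (original self-conditioning).
    variant == "next_step": first pass on the noisier x(t+1) (RFdiffusion-style).
    """
    zeros = [0.0] * len(x_t)
    if rng.random() >= p:
        return zeros                      # train without self-conditioning
    source = x_t if variant == "same_step" else x_next
    return list(model(source, zeros))     # copy stands in for stop-gradient
```

In both variants the second pass is run on x(t) with this input, so the loss target is unchanged; only the provenance of the self-conditioning estimate differs.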

I can't tell from your code snippet in which direction you're adding self-conditioning. Are you making sure you're not self-conditioning on t-dt?
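
One way to rule out the direction problem is a quick check on the feature dictionary before training. The sketch below assumes the key names `t` and `self_condition_t` from the snippet posted earlier in the thread; the helper name is hypothetical.

```python
def check_self_condition_direction(feats):
    """Raise if the self-conditioning sample is less noisy than the training sample."""
    t_train = feats["t"]
    t_cond = feats["self_condition_t"]
    if t_cond < t_train:
        raise ValueError(
            f"self-conditioning at t={t_cond} is less noisy than training input at t={t_train}"
        )
    return t_cond - t_train
```

Calling it on a few samples from the dataset would show whether the self-conditioning features sit at t and the training features at t-dt, and not the other way round.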

Wangchentong commented 1 year ago

I am fairly sure there is no bug (at least I have not found one after multiple checks). diff_feats_t_prev is the diffused feature at time t, used as the self-conditioning input; diff_rigid_t is the training input, created by a reverse step of size dt from rigid_t_prev. I also did not expect this result. If you are also interested, you can try it and we can discuss later 😃

jasonkyuyim commented 7 months ago

Closing due to inactivity.