lucidrains / recurrent-interface-network-pytorch

Implementation of Recurrent Interface Network (RIN), for highly efficient generation of images and video without cascading networks, in Pytorch

Inconsistent interpretation of model output between self-conditioning step and prediction step #15

Open jsternabsci opened 1 year ago

jsternabsci commented 1 year ago

I think there is a bug in the interface self-conditioning in rin_pytorch.py.

The model output is interpreted differently during the self-conditioning stage compared to the prediction stage.

Currently we have (pseudocode):

self_cond = x0_to_target_modification(model_output)           # Treat the model prediction as x0 and convert it to x0, eps, or v
pred = self.model(..., self_cond, ...)                        # Self-condition on the converted prediction and predict x0, eps, or v
target = x0_to_target_modification(x0)                        # Ground-truth target in the same parameterization
loss = F.mse_loss(pred, target)

In the current implementation, the interface prediction is interpreted as x0 during self-conditioning, but as the target (x0, eps, or v) at the prediction step.
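To make the 'eps' case concrete, here is a toy restatement of the two quantities involved (hypothetical shapes and values; it simply mirrors the 'eps' branch of the code quoted further below):

import torch

# toy stand-ins for one training example (hypothetical shapes and values)
alpha, sigma = 0.8, 0.6                         # schedule coefficients with alpha**2 + sigma**2 == 1
x0  = torch.randn(1, 3, 8, 8).clamp(-1., 1.)    # clean image
eps = torch.randn(1, 3, 8, 8)                   # true noise
noised_img = alpha * x0 + sigma * eps

model_output = torch.randn(1, 3, 8, 8)          # stand-in for the network output under the 'eps' objective

# self-conditioning signal, as computed by the 'eps' branch of the code quoted below
self_cond = ((noised_img - sigma * model_output) / alpha).clamp(-1., 1.)

# regression target that pred is scored against for the same objective
target = eps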

I see two ways that we could do interface self-conditioning consistently. In contrast to the current implementation, both of my proposals interpret the interface prediction the same way at the self-conditioning step and at the prediction step. Would you agree that there is an inconsistency here, and that either proposal would resolve it?
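For illustration only, here is a sketch of one way to make the two steps agree (written in the style of the code below; it is not necessarily either of the two proposals): feed back the raw prediction, so that it means the same thing at the self-conditioning step as at the loss.

if random() < self.train_prob_self_cond:
    with torch.no_grad():
        model_output, self_latents = self.model(noised_img, times, return_latents = True)
        self_latents = self_latents.detach()

        # condition directly on the raw prediction (x0, eps, or v), i.e. the same
        # quantity that the loss compares against the target
        self_cond = model_output.detach()

        if self.objective == 'x0':
            self_cond = self_cond.clamp(-1., 1.)    # clamping to the image range only makes sense in x0 space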

Here is the current code:

if random() < self.train_prob_self_cond:
    with torch.no_grad():
        model_output, self_latents = self.model(noised_img, times, return_latents = True)
        self_latents = self_latents.detach()

        if self.objective == 'x0':
            self_cond = model_output

        elif self.objective == 'eps':
            # recover the x0 implied by the eps prediction: x0 = (x_t - sigma * eps) / alpha
            self_cond = safe_div(noised_img - sigma * model_output, alpha)

        elif self.objective == 'v':
            # recover the x0 implied by the v prediction: x0 = alpha * x_t - sigma * v
            self_cond = alpha * noised_img - sigma * model_output

        self_cond.clamp_(-1., 1.)
        self_cond = self_cond.detach()

# predict and take gradient step

pred = self.model(noised_img, times, self_cond, self_latents)

...

loss = F.mse_loss(pred, target, reduction = 'none')
loss = reduce(loss, 'b ... -> b', 'mean')

lucidrains commented 1 year ago

@jsternabsci i have only seen self conditioning done with the predicted x0 (correct me if i'm wrong)

there's nothing inconsistent, as during inference self conditioning is also done with the predicted x0

however, i get what you are saying. i could offer both options, if you are running experiments to see which way is better?
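For context on the inference-time behaviour described above, here is a minimal sketch of sampling with self-conditioning on the predicted x0 (simplified pseudocode for the 'eps' objective; time_pairs, schedule and ddim_step are placeholders, not the repo's actual functions):

x0_pred = None    # no self-conditioning signal at the first denoising step

for time, time_next in time_pairs:
    alpha, sigma = schedule(time)

    # the previous step's predicted x0 is fed back as the self-conditioning input
    model_output = model(img, time, x0_pred)

    # recover the x0 implied by the eps prediction and clamp it, as in training
    x0_pred = ((img - sigma * model_output) / alpha).clamp(-1., 1.)

    # DDPM / DDIM update toward the next timestep using x0_pred
    img = ddim_step(img, x0_pred, time, time_next)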