Kinyugo / consistency_models

A mini-library for training consistency models.
https://arxiv.org/abs/2303.01469
MIT License

SamplingAndEditing: mask transform should also apply to x? #14

Open sentient-codebot opened 2 months ago

sentient-codebot commented 2 months ago

In the ConsistencySamplingAndEditing class, the __mask_transform method, which I believe corresponds to Algorithm 4 in the paper, essentially applies the matrix A from the paper. If so, shouldn't transform_fn also be applied to x?
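For reference, as I read Algorithm 4, the masked update maps both the reference image y and the current sample x through the invertible linear transform A before blending them with the binary mask Ω (⊙ is the elementwise product). Here transform_fn plays the role of A, inverse_transform_fn of A^{-1}, and mask of Ω:

$$
x \leftarrow A^{-1}\left[ (A y) \odot (1 - \Omega) + (A x) \odot \Omega \right]
$$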

```python
def __mask_transform(
    self,
    x: Tensor,
    y: Tensor,
    mask: Tensor,
    transform_fn: Callable[[Tensor], Tensor] = lambda x: x,
    inverse_transform_fn: Callable[[Tensor], Tensor] = lambda x: x,
) -> Tensor:
    return inverse_transform_fn(transform_fn(y) * (1.0 - mask) + x * mask)
```
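A minimal sketch of the suggested change, written as a standalone function so it runs on its own. Assumptions: the name mask_transform is hypothetical (it only mirrors the method above), it uses NumPy arrays instead of the library's PyTorch tensors, and the random orthogonal matrix A is purely illustrative. This is not the maintainer's actual fix, only one way to apply transform_fn to x as well as y.

```python
from typing import Callable

import numpy as np


def _identity(a: np.ndarray) -> np.ndarray:
    return a


def mask_transform(
    x: np.ndarray,
    y: np.ndarray,
    mask: np.ndarray,
    transform_fn: Callable[[np.ndarray], np.ndarray] = _identity,
    inverse_transform_fn: Callable[[np.ndarray], np.ndarray] = _identity,
) -> np.ndarray:
    # Hypothetical standalone version of x <- A^{-1}[(A y) * (1 - mask) + (A x) * mask]:
    # both y and x are mapped into the transformed space before the masked
    # blend, and the result is mapped back with the inverse transform.
    return inverse_transform_fn(
        transform_fn(y) * (1.0 - mask) + transform_fn(x) * mask
    )


# Illustrative transform: a random orthogonal matrix acting on flat vectors.
rng = np.random.default_rng(0)
A, _ = np.linalg.qr(rng.normal(size=(4, 4)))


def to_A(v: np.ndarray) -> np.ndarray:
    return A @ v


def from_A(v: np.ndarray) -> np.ndarray:
    # A is orthogonal, so its inverse is its transpose.
    return A.T @ v


x = rng.normal(size=4)
y = rng.normal(size=4)
# In A-space: 1 = take the entry from x, 0 = keep the entry from y.
mask = np.array([1.0, 1.0, 0.0, 0.0])

out = mask_transform(x, y, mask, to_A, from_A)
```

With the default identity transforms, this version and the quoted one return the same result, so the difference only shows up when a non-trivial transform_fn is passed.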
Kinyugo commented 2 months ago

You're right. I'll update the code along with some other fixes I'm including.