Open sentient-codebot opened 2 months ago
In the class ConsistencySamplingAndEditing, the __mask_transform method, which I think corresponds to Algorithm 4 in the paper, essentially applies the matrix A from the paper. In that case, shouldn't transform_fn also be applied to x? At the moment it is only applied to y.
    def __mask_transform(
        self,
        x: Tensor,
        y: Tensor,
        mask: Tensor,
        transform_fn: Callable[[Tensor], Tensor] = lambda x: x,
        inverse_transform_fn: Callable[[Tensor], Tensor] = lambda x: x,
    ) -> Tensor:
        return inverse_transform_fn(transform_fn(y) * (1.0 - mask) + x * mask)
You are right. I'll update the code along with some other fixes.
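For reference, here is a minimal sketch of what the change discussed above might look like. It is not the maintainer's actual fix. It assumes the only change is wrapping x in transform_fn, and it uses plain floats in place of tensors so it runs without PyTorch (the same arithmetic would apply elementwise to tensors). The standalone function names and the toy scale/unscale transform are hypothetical and only there to make the difference visible.

```python
from typing import Callable

Transform = Callable[[float], float]


def mask_transform_original(
    x: float,
    y: float,
    mask: float,
    transform_fn: Transform = lambda t: t,
    inverse_transform_fn: Transform = lambda t: t,
) -> float:
    # Behaviour quoted in the issue: only y is moved into the transformed space.
    return inverse_transform_fn(transform_fn(y) * (1.0 - mask) + x * mask)


def mask_transform_proposed(
    x: float,
    y: float,
    mask: float,
    transform_fn: Transform = lambda t: t,
    inverse_transform_fn: Transform = lambda t: t,
) -> float:
    # Assumed fix: transform both x and y, blend with the mask, then invert.
    return inverse_transform_fn(
        transform_fn(y) * (1.0 - mask) + transform_fn(x) * mask
    )


# Toy invertible transform, chosen only for illustration.
def scale(t: float) -> float:
    return t * 4.0


def unscale(t: float) -> float:
    return t / 4.0


# With mask == 1 the result should be x itself after the round trip.
original_result = mask_transform_original(3.0, 5.0, 1.0, scale, unscale)
proposed_result = mask_transform_proposed(3.0, 5.0, 1.0, scale, unscale)
print(original_result)  # 0.75: x was inverse-transformed without being transformed
print(proposed_result)  # 3.0: x survives the round trip
```

With the default identity transforms the two versions agree, which may be why the discrepancy is easy to miss.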