Thanks for sharing the code. There is a possible bug you may want to check:
```python
if start_t is None:
    model_input = torch.cat([inputs[:, :6], image], dim=1)
    timesteps = self.scheduler.timesteps
else:
    model_input = inputs
    timesteps = self.scheduler.timesteps[-start_t:]
    # Needs your check: when start_t is a non-zero step (for example, in the
    # coarse-to-fine pattern), we do not want to start from random noise,
    # so initialize `image` with the coarse flow already predicted.
    image = inputs[:, -2:, :, :]  # coarse flow predicted in the previous stage
```
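To make the intended behaviour concrete, here is a minimal, framework-free sketch of how the initial sample and the timestep schedule could be chosen. It assumes the channel layout implied by the snippet (6 image channels followed by 2 coarse-flow channels); `prepare_sampling` and the list-of-channels representation are hypothetical and only stand in for the real tensors and scheduler.

```python
import random


def prepare_sampling(inputs, timesteps, start_t=None, seed=0):
    """Hypothetical helper: pick the initial flow sample and the timesteps to run.

    inputs:    list of channels, each a flat list of floats. Assumed layout:
               6 image channels, then (in the coarse-to-fine case) 2 coarse-flow channels.
    timesteps: full sampling schedule, ordered from noisiest to cleanest.
    start_t:   None for full sampling, else how many final steps to run.
    """
    if start_t is None:
        # Full sampling: the 2 flow channels start from pure Gaussian noise.
        rng = random.Random(seed)
        n = len(inputs[0])
        sample = [[rng.gauss(0.0, 1.0) for _ in range(n)] for _ in range(2)]
        return sample, list(timesteps)

    if start_t <= 0:
        # Guard: timesteps[-0:] is the same as timesteps[0:], i.e. the FULL
        # schedule, which would silently contradict "resume from a later step".
        raise ValueError("start_t must be a positive number of steps")

    # Coarse-to-fine: resume from the coarse flow (last 2 channels) instead of
    # fresh noise, and only run the last `start_t` steps of the schedule.
    sample = [list(ch) for ch in inputs[-2:]]
    return sample, list(timesteps[-start_t:])


steps = [999, 749, 499, 249, 0]
inputs = [[0.0] * 4 for _ in range(6)] + [[1.0] * 4, [2.0] * 4]
sample, run_steps = prepare_sampling(inputs, steps, start_t=2)
print(sample, run_steps)
```

With `start_t=2` the sample is the coarse flow itself and only the last two steps (`[249, 0]`) are run; with `start_t=None` the loop starts from noise over the whole schedule, which matches the first branch of the snippet.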