CompVis / depth-fm

DepthFM: Fast Monocular Depth Estimation with Flow Matching
MIT License
395 stars · 27 forks

training problem #9

Closed YongtaoGe closed 7 months ago

YongtaoGe commented 7 months ago

Hi, authors! Thanks for open-sourcing the inference code. I am interested in re-implementing the training process, but my results look wrong. Here is the code snippet I use; I would appreciate it if you could help me spot the problem.

  # x_1: clean depth latents; x_0: noise-augmented image latents
  x_1 = depth_latents
  t = torch.rand((rgb_latents.size(0),))
  num_train_timesteps = 1000
  x_0 = q_sample(rgb_latents, t=t * num_train_timesteps)
  # reshape t so it broadcasts over (B, C, H, W)
  t_ = t.view(-1, 1, 1, 1)
  x_t = (1 - t_) * x_0 + t_ * x_1
  targets = x_1 - x_0
  pred = UNet(x_t, t)
  loss = F.mse_loss(pred, targets)

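(The `q_sample` used above is not shown in the thread. A minimal stand-in under standard DDPM forward-process assumptions — the linear beta schedule and signature here are guesses for illustration, not DepthFM's actual implementation:)

```python
import torch

def q_sample(x_0, t, num_train_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Standard DDPM forward noising:
    x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps, eps ~ N(0, I).
    t is a per-sample timestep in [0, num_train_timesteps)."""
    betas = torch.linspace(beta_start, beta_end, num_train_timesteps)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    # gather a_bar_t per sample and reshape to broadcast over (B, C, H, W)
    a_bar = alphas_cumprod[t.long().clamp(0, num_train_timesteps - 1)].view(-1, 1, 1, 1)
    noise = torch.randn_like(x_0)
    return a_bar.sqrt() * x_0 + (1.0 - a_bar).sqrt() * noise
```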
mgui7 commented 7 months ago

Hey, thank you for your interest in this project! One issue: we use a fixed timestep of 400 in q_sample, since it is only responsible for noise augmentation. We also pass the image latents to the UNet as conditioning, so the second-to-last line should be something like `pred = UNet(x_t, t, rgb_latents)`.
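Putting both corrections together, one possible shape of the training step is sketched below. The `noise_augment` helper (a stand-in for q_sample at a fixed step), its schedule values, and the `training_step` name are illustrative assumptions, not the authors' code:

```python
import torch
import torch.nn.functional as F

def noise_augment(x, step, num_train_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    # Stand-in for q_sample at a fixed step: standard DDPM forward noising.
    betas = torch.linspace(beta_start, beta_end, num_train_timesteps)
    a_bar = torch.cumprod(1.0 - betas, dim=0)[step]
    return a_bar.sqrt() * x + (1.0 - a_bar).sqrt() * torch.randn_like(x)

def training_step(unet, rgb_latents, depth_latents, noise_aug_step=400):
    b = rgb_latents.size(0)
    t = torch.rand(b, device=rgb_latents.device)      # flow time in [0, 1)
    x_1 = depth_latents                               # clean depth latents
    x_0 = noise_augment(rgb_latents, noise_aug_step)  # fixed timestep 400
    t_ = t.view(-1, 1, 1, 1)                          # broadcast over (B, C, H, W)
    x_t = (1.0 - t_) * x_0 + t_ * x_1                 # linear interpolation path
    target = x_1 - x_0                                # flow-matching velocity target
    pred = unet(x_t, t, rgb_latents)                  # image latents as conditioning
    return F.mse_loss(pred, target)
```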

pramishp commented 1 month ago

@YongtaoGe, did you manage to train successfully? If so, could you please share the code?