anucvml / ddn

Deep Declarative Networks

batched operation in pytorch pnp node #16

Open flandrewries opened 2 years ago

flandrewries commented 2 years ago

Hello, I'm using the released code of your pnp_node.py. My inputs are batched points, each with a different pose, so I would like to use this operation:

    # # Alternatively, disentangle batch element optimization:
    # for i in range(p2d.size(0)):
    #     Ki = K[i:(i+1),...] if K is not None else None
    #     theta[i, :] = self._run_optimization(p2d[i:(i+1),...],
    #         p3d[i:(i+1),...], w[i:(i+1),...], Ki, y=theta[i:(i+1),...])

However, I find that the upper-level function does not update the w value. I printed theta.grad to check whether the gradient is calculated, and found that theta[i:(i+1),...].grad is None. Perhaps the slice or copy operations do not carry over the grad value once the optimization is done. Is there any way to solve this problem?
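Here is a minimal repro of what I'm seeing (plain PyTorch, independent of ddn; the shapes are illustrative):

    import torch

    # Slicing creates a new non-leaf tensor, and autograd only populates
    # .grad on leaf tensors by default, so the slice's .grad is None even
    # though the underlying leaf tensor receives gradients.
    theta = torch.randn(4, 6, requires_grad=True)  # leaf tensor
    loss = (theta[0:1, ...] ** 2).sum()
    loss.backward()
    print(theta[0:1, ...].grad)   # None: this slice is a fresh non-leaf tensor
    print(theta.grad[0:1, ...])   # the gradient lives on the leaf tensor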

I would greatly appreciate your advice.

dylan-campbell commented 2 years ago

Hi @flandrewries, the solve function already estimates a different theta (pose) for each batch element. The commented-out code is an inefficient debugging aid that can be useful if you don't want to run the optimisation all the way to convergence for every batch element. Solve itself is not part of the computation graph, so you won't get any gradients there, but you should get the correct gradient values for the inputs to your DeclarativeLayer. Check out the PnP example here for details.
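For reference, a minimal sketch of that usage (the PnP class name, constructor arguments, and tensor shapes are assumed here; the repository's PnP example has the exact API):

    import torch
    from ddn.pytorch.node import DeclarativeLayer
    from ddn.pytorch.pnp_node import PnP  # class name assumed from pnp_node.py

    node = PnP()                    # declarative PnP node (default arguments assumed)
    layer = DeclarativeLayer(node)  # wraps solve and the implicit backward pass

    b, n = 2, 8                     # batch size and points per element (illustrative)
    p2d = torch.randn(b, n, 2, requires_grad=True)  # batched 2D image points
    p3d = torch.randn(b, n, 3, requires_grad=True)  # batched 3D world points
    w = torch.ones(b, n, requires_grad=True)        # per-point weights

    theta = layer(p2d, p3d, w)      # one pose per batch element
    loss = theta.square().sum()     # any downstream loss
    loss.backward()
    print(w.grad is not None)       # True: gradients reach the layer's inputs

The gradients land on the layer's input tensors (p2d, p3d, w) via the implicit differentiation of the declarative node, not on the theta produced inside solve, which is why inspecting theta slices shows no grad.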