mikh3x4 / nerf-navigation

Code for the Nerf Navigation Paper. Implements a trajectory optimiser and state estimator that use NeRFs as an environment representation.
https://mikh3x4.github.io/nerf-navigation/
MIT License
186 stars, 24 forks

NotImplementedError on measurement_fn in estimator_helpers.py #16

Closed: jfrausto7 closed this issue 1 year ago

jfrausto7 commented 1 year ago

When running simulate.py as described in the README, I'm encountering the following error:

Traceback (most recent call last):
  File "C:\Users\jaf25\OneDrive\Documents\Github\nerf-navigation\simulate.py", line 350, in <module>
    simulate(planner_cfg, agent_cfg, filter_cfg, extra_cfg, density_fn, render_fn, get_rays_fn)
  File "C:\Users\jaf25\OneDrive\Documents\Github\nerf-navigation\simulate.py", line 88, in simulate
    state_est = filter.estimate_state(gt_img, true_pose, action)
  File "C:\Users\jaf25\OneDrive\Documents\Github\nerf-navigation\nav\estimator_helpers.py", line 374, in estimate_state
    xt, success_flag = self.estimate_relative_pose(sensor_img, self.xt.clone().detach(), sig_prop, obs_img_pose=obs_img_pose)
  File "C:\Users\jaf25\OneDrive\Documents\Github\nerf-navigation\nav\estimator_helpers.py", line 240, in estimate_relative_pose
    loss.backward()
  File "C:\Users\jaf25\anaconda3\lib\site-packages\torch\_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "C:\Users\jaf25\anaconda3\lib\site-packages\torch\autograd\__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "C:\Users\jaf25\anaconda3\lib\site-packages\torch\autograd\function.py", line 288, in apply
    return user_fn(self, *args)
  File "C:\Users\jaf25\anaconda3\lib\site-packages\torch\autograd\function.py", line 404, in backward
    raise NotImplementedError(
NotImplementedError: You must implement either the backward or vjp method for your custom autograd.Function to use it with backward mode AD.
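The traceback ends inside `torch/autograd/function.py`, which suggests the loss passes through a custom `torch.autograd.Function` that defines `forward` but no `backward`. A minimal, self-contained reproduction under that assumption (the class `DoubleNoBackward` is hypothetical and not taken from the repository): the forward pass succeeds, and the error only appears once `backward()` reaches the custom node, just as in the traceback above.

```python
import torch


class DoubleNoBackward(torch.autograd.Function):
    """Hypothetical custom op that defines forward but not backward."""

    @staticmethod
    def forward(ctx, x):
        return x * 2
    # No backward (or vjp) is defined, so the base-class default is used,
    # and that default raises NotImplementedError when autograd calls it.


def reproduce_missing_backward():
    """Return the message raised by backward(), or None if nothing is raised."""
    x = torch.ones(3, requires_grad=True)
    loss = DoubleNoBackward.apply(x).sum()  # forward runs without error
    try:
        loss.backward()  # fails only when autograd reaches the custom node
    except NotImplementedError as err:
        return str(err)
    return None


if __name__ == "__main__":
    print(reproduce_missing_backward())
```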

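For whatever reason, PyTorch raises this when loss.backward() is called in estimate_relative_pose; the message suggests that a custom autograd.Function somewhere in the loss's graph has no backward method. Has anyone else come across this before and found a solution? (For reference, I'm using PyTorch 2.1.0 + CUDA 11.8.)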
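The general remedy for this error is to give the custom `Function` a `backward` staticmethod (or to replace it with plain differentiable torch operations so autograd can trace them directly). A minimal sketch, assuming a hypothetical `Square` op that is not the repository's `measurement_fn`; `torch.autograd.gradcheck` is used to compare the hand-written backward against finite differences:

```python
import torch


class Square(torch.autograd.Function):
    """Hypothetical custom op: y = x**2, with a hand-written backward."""

    @staticmethod
    def forward(ctx, x):
        # Save the input; backward needs it to compute dy/dx = 2x.
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        # Chain rule: dL/dx = dL/dy * dy/dx = grad_output * 2x.
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x


def check_square_grad():
    # gradcheck compares the analytic backward against finite differences;
    # it expects double-precision inputs that require grad.
    x = torch.randn(5, dtype=torch.double, requires_grad=True)
    return torch.autograd.gradcheck(Square.apply, (x,))


if __name__ == "__main__":
    x = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
    Square.apply(x).sum().backward()  # no NotImplementedError now
    print(x.grad)  # equals 2 * x
    print(check_square_grad())
```

If the real function is already built from ordinary torch ops, simply calling it directly instead of wrapping it in a custom `Function` avoids the need for a hand-written backward altogether.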