nerfstudio-project / nerfacc

A General NeRF Acceleration Toolbox in PyTorch.
https://www.nerfacc.com/

Issue with inplace operation #236

Open ridoughi opened 1 year ago

ridoughi commented 1 year ago

I am trying to re-implement my code using NerfAcc. In my application, I need to estimate some displacements using an MLP and then apply the rendering to the displaced positions (positions = positions + dpos).

However, I am struggling with the following error:

    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [6, 3]], which is output 0 of IndexPutBackward0, is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

This error occurs in my updated rgb_sigma_fn function. Here is the code:

    def rgb_sigma_fn(t_starts, t_ends, ray_indices):
        t_origins = chunk_rays.origins[ray_indices]
        t_dirs = chunk_rays.viewdirs[ray_indices]
        positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
        TFrames = timestamps[ray_indices]
        dpos = radiance_field.compute_displacement(TFrames, positions)
        positions = positions + dpos
        rgbs, sigmas = radiance_field(positions, TFrames)
        return rgbs, sigmas.squeeze(-1)

Here, radiance_field.compute_displacement simply applies an MLP to the encoded timestamps and positions:

    def compute_displacement(self, timestamps, x):
        x = self.position_encoding(x)
        t = self.timestamps_encoding(timestamps)
        time_pos = torch.cat([t, x], -1)
        dpos = self.displace_MLP(time_pos)
        return dpos

In my previous implementation, I did not get this kind of error. I hope that you can help me solve this issue so that I can use your impressive code.

Thank you for your time. Best,

liruilong940607 commented 1 year ago

Hi, this code does not contain any in-place operations. You would want to look for something like += or *= in the entire pipeline. In general, PyTorch autograd does not allow in-place operations on tensors that it needs for the backward pass.
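
For anyone hitting the same message, here is a minimal sketch of how this class of error can arise and one way to avoid it. It is not the code from this issue: the tensor names, shapes, and the helper backward_succeeds are hypothetical, and it runs on CPU. Since the error above mentions IndexPutBackward0, the sketch uses an indexed assignment (x[idx] = value) rather than += as the in-place operation.

```python
import torch


def backward_succeeds(use_inplace: bool) -> bool:
    """Return True if backward() runs, False if autograd raises a RuntimeError."""
    # Hypothetical stand-ins for the tensors in the issue; shapes are arbitrary.
    weights = torch.ones(6, 3, requires_grad=True)
    idx = torch.tensor([0, 1])

    positions = weights * 2.0    # non-leaf tensor tracked by autograd
    density = positions.pow(2)   # pow() saves `positions` for its backward pass

    if use_inplace:
        # Indexed assignment writes into `positions` in place, which bumps
        # its version counter after autograd has already saved it.
        positions[idx] = 0.0
    else:
        # Out-of-place alternative: write into a copy so the tensor saved
        # by pow() is left untouched.
        positions = positions.clone()
        positions[idx] = 0.0

    try:
        density.sum().backward()
        return True
    except RuntimeError:
        return False


if __name__ == "__main__":
    print("in-place:", backward_succeeds(True))
    print("out-of-place:", backward_succeeds(False))
```

If the offending line is hard to find, calling torch.autograd.set_detect_anomaly(True) before the forward pass makes PyTorch also report the forward operation whose backward failed, which usually narrows down where to look.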