NVlabs / nvdiffrec

Official code for the CVPR 2022 (oral) paper "Extracting Triangular 3D Models, Materials, and Lighting From Images".

Would it be possible to modify this to also use CPU memory? #15

Open · sirisian opened this issue 2 years ago

sirisian commented 2 years ago

I looked at issue #2 and I'm curious whether you could modify the project to use CPU memory by swapping images and other data structures in and out as needed (I'm not sure if https://github.com/IBM/pytorch-large-model-support could be used for the tensor part). I noticed the limitations section in the paper already mentions the memory consumption, so if this isn't easy, feel free to close this.

Part of my motivation is that I'm interested in seeing what happens with more, and larger, images. (This assumes training duration isn't a high priority and model quality is the goal; for reference, I have a 3090 and a lot of DDR5 memory.)

jmunkberg commented 2 years ago

Hello,

More images should work out of the box.

Larger images than, say 2k x 2k, are trickier to support.

You can work around this by e.g., rendering random crops from larger images in each training iteration.

JHnvidia commented 2 years ago

Hi @sirisian,

I don't think we'll do significant memory optimizations to the code in the short term. We want it to stay true to the paper version, and as researchers we have limited time to dedicate to this release.

The package looks fairly promising, but I'm not sure how it manages GPU/CPU transitions. As Jacob mentioned above, we rely on quite a few CUDA kernels that have no CPU fallback. If the LMS library always runs kernels on the CUDA device and just temporarily swaps tensors back to CPU memory, it could work, but if it requires both CPU and CUDA kernels it will be very hard.
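For what it's worth, a minimal sketch of that swap pattern in plain PyTorch (my own illustration, not part of nvdiffrec): keep the reference images in CPU memory and move only the current batch to the GPU each iteration, so every kernel still runs on the CUDA device.

```python
import torch

# Hypothetical illustration: hold the full image set in pinned CPU memory
# and copy only the current batch to the GPU per iteration. All kernels keep
# running on the CUDA device; only the data is swapped in and out.
images_cpu = torch.rand(100, 512, 512, 4).pin_memory()   # [N, H, W, C] reference images

for it in range(1000):
    idx = torch.randint(0, images_cpu.shape[0], (2,))    # small random batch
    batch = images_cpu[idx].cuda(non_blocking=True)      # CPU -> GPU transfer
    # ... run the CUDA-only forward/backward passes on `batch` ...
    del batch                                            # release GPU memory before the next copy
```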

Another optimization you might want to try is to change the batching code. Currently we batch by rendering to [N, H, W, C] tensors, where N is the batch size. Temporary results need to be stored at the same resolution, so memory consumption grows linearly with N. The main benefit of batching is gradient averaging, so it would be possible to instead run a loop over N x [1, H, W, C] forward + backward passes and just average the final gradients.
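For reference, a minimal sketch of that gradient-averaging loop in plain PyTorch (my own illustration with stand-in modules, not the nvdiffrec training loop):

```python
import torch

# Hypothetical gradient accumulation: replace one [N, H, W, C] batch with
# N sequential single-view forward/backward passes. Scaling each loss by 1/N
# makes the accumulated gradients equal the batch mean, while peak memory
# only has to hold the temporaries of a single view at a time.
N, C, H, W = 8, 3, 512, 512
model = torch.nn.Conv2d(C, C, 3, padding=1)              # stand-in for the renderer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
images = torch.rand(N, C, H, W)                          # stand-in reference views

optimizer.zero_grad()
for i in range(N):
    pred = model(images[i:i+1])                          # forward on one view
    loss = torch.nn.functional.mse_loss(pred, images[i:i+1]) / N
    loss.backward()                                      # gradients accumulate in .grad
optimizer.step()                                         # one step with averaged gradients
```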

sirisian commented 2 years ago

Thanks for all the ideas.

> More images should work out of the box.
>
> Larger images than, say 2k x 2k, are trickier to support.
>
> You can work around this by e.g., rendering random crops from larger images in each training iteration.

I'm wondering if there's a more naive way to do something like this without modifying the code (or modifying it only slightly). For example, could one take a single 9Kx7K camera image and treat it like 9x7=63 cameras, each with a resolution of 1Kx1K pixels, and throw away the masked-out cameras? (So only the fake cameras that aren't masked out are included.) The frustums wouldn't follow normal camera intrinsic definitions though, so I imagine this wouldn't be viable without relating them back to a single original camera frustum and its intrinsics.

jmunkberg commented 2 years ago

Yes, you need to adjust the camera frustum. We have a function `util.perspective_offcenter` that does this. We tested random cropping at some point but removed it for the public release.

In `dataset_llff.py` or `dataset_nerf.py`, you can modify the end of the `_parse_frame(self, idx)` function with something like this to enable random cropping (the code below is untested, but you get the idea):

```python
...
CROP_SIZE = 256
height = img.shape[0]
width  = img.shape[1]

# Pick a random CROP_SIZE x CROP_SIZE window and crop the reference image
xstart = np.random.randint(0, width - CROP_SIZE)
ystart = np.random.randint(0, height - CROP_SIZE)
img = img[ystart:ystart + CROP_SIZE, xstart:xstart + CROP_SIZE, :]
_rx = xstart / width
_ry = ystart / height

# Override projection matrix and mvp with an off-center frustum matching the crop
proj_mtx = util.perspective_offcenter(fovy, CROP_SIZE / width, _rx, _ry, width / height, self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1])
mvp      = proj_mtx @ mv
```
mexicantexan commented 2 years ago

@jmunkberg adding this code in seems to break an assertion in `render/render.py` line 215: `assert background is None or (background.shape[1] == resolution[0] and background.shape[2] == resolution[1])`. Any suggestions on working around this?

For context, I changed the end of _parse_frame in dataset_llff.py to:

```python
...
        if self.FLAGS.crop_dataset is True:
            print('cropping dataset')
            if self.FLAGS.CROP_SIZE is None:
                CROP_SIZE = 256
            else:
                CROP_SIZE = self.FLAGS.CROP_SIZE
            height = img.shape[0]
            width = img.shape[1]
            xstart = np.random.randint(0, width - CROP_SIZE)
            ystart = np.random.randint(0, height - CROP_SIZE)
            img = img[ystart:ystart + CROP_SIZE, xstart:xstart + CROP_SIZE, :]
            _rx = xstart / width
            _ry = ystart / height

            # Override projection matrix and mvp
            proj = util.perspective_offcenter(self.fovy[idx, ...], CROP_SIZE / width, _rx, _ry, width / height,
                                              self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1])
        else:
            # Setup transforms
            proj = util.perspective(self.fovy[idx, ...], self.aspect, self.FLAGS.cam_near_far[0],
                                    self.FLAGS.cam_near_far[1])

        mv = torch.linalg.inv(self.imvs[idx, ...])
        campos = torch.linalg.inv(mv)[:3, 3]
        mvp = proj @ mv

        return img[None, ...], mv[None, ...], mvp[None, ...], campos[None, ...]
```

Here I pass the custom flags `self.FLAGS.crop_dataset` and `self.FLAGS.CROP_SIZE` in with the rest of the flags.
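A hedged guess at the mismatch (my own reading, not confirmed in this thread): the cropped image is `CROP_SIZE x CROP_SIZE`, while the `resolution` (and any `background` tensor) passed into `render` still reflects the full training resolution, so the assert fires. A quick way to confirm is to print the shapes just before the failing assert:

```python
# Hypothetical debugging aid, placed just before the assert in render/render.py,
# to confirm that the background/resolution shapes disagree with the crop size.
print('background:', None if background is None else tuple(background.shape),
      'resolution:', tuple(resolution))
```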

hzhshok commented 1 year ago

Anyway, I assume that `self.fovy[idx, ...]` should not be indexed with `idx` (the frame id) if your images are RGBA type.