NVlabs / tiny-cuda-nn

Lightning fast C++/CUDA neural network framework

Double backwards support for tcnn.Network #446

Open adamdai opened 5 months ago

adamdai commented 5 months ago

Hello,

I am using a tcnn hash encoding + CutlassMLP to train a neural field. I would like access to second derivatives so that I can add a smoothing loss during training, and I also need gradients of the field for trajectory optimization over it.

I noticed that double-backward support has been implemented for tcnn.Encoding here. However, that example uses a torch MLP, which is much slower. When I try to obtain second derivatives with a tcnn.Network MLP instead, a backward_backward_input_impl: not implemented error is thrown.
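
For reference, the configuration that does work for me is tcnn.Encoding followed by a plain torch MLP. Below is a minimal sketch of that working path (the small torch.nn.Sequential MLP and the float() cast of the encoding output are my own simplifications, not code from the linked example):

import torch
import tinycudann as tcnn

# Working configuration: tcnn hash encoding + ordinary torch MLP.
encoding = tcnn.Encoding(
    n_input_dims=2,
    encoding_config={
        "otype": "HashGrid",
        "n_levels": 8,
        "n_features_per_level": 8,
        "log2_hashmap_size": 19,
        "base_resolution": 16,
        "per_level_scale": 1.2599210739135742,
        "interpolation": "Smoothstep",  # twice differentiable, unlike Linear
    },
)
mlp = torch.nn.Sequential(
    torch.nn.Linear(encoding.n_output_dims, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 1),
).cuda()

x = torch.rand(10, 2, device="cuda", requires_grad=True)
y = mlp(encoding(x).float())  # encoding output may be half precision

# Both calls succeed because the MLP is ordinary autograd-tracked torch code.
grad = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
grad_2 = torch.autograd.grad(grad.sum(), x)[0]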

Here is a minimal example reproducing the error:

import torch
import tinycudann as tcnn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class NeuralField(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.encoding = tcnn.Encoding(
            n_input_dims=2,
            encoding_config={
                "otype": "HashGrid",
                "n_levels": 8,
                "n_features_per_level": 8,
                "log2_hashmap_size": 19,
                "base_resolution": 16,
                "per_level_scale": 1.2599210739135742,
                "interpolation": "Smoothstep"
            },
        )
        tot_out_dims_2d = self.encoding.n_output_dims

        self.mlp = tcnn.Network(
            n_input_dims=tot_out_dims_2d,
            n_output_dims=1,
            network_config={
                "otype": "CutlassMLP",
                "activation": "ReLU",
                "output_activation": "None",
                "n_neurons": 256,
                "n_hidden_layers": 3,
            },
        )

    def forward(self, x):
        x = self.encoding(x)
        x = self.mlp(x)
        return x

if __name__ == "__main__":
    nf = NeuralField()
    nf.to(device)

    x = torch.rand(10, 2, requires_grad=True).to(device)
    y = nf(x)

    # 2 methods of obtaining 2nd derivative - both give the same error
    method = 0

    if method == 0:
        grad = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
        grad_2 = torch.autograd.grad(grad.sum(), x)[0]
    else:
        grad = torch.autograd.grad(y, x, torch.ones_like(y, device=x.device),
                                   create_graph=True, retain_graph=True, only_inputs=True)[0]
        grad_2 = torch.autograd.grad(grad, x, torch.ones_like(grad, device=x.device),
                                     create_graph=False, retain_graph=False, only_inputs=True)[0]

The full error trace is:

Traceback (most recent call last):
  File "tcnn_double_backward.py", line 54, in <module>
    grad_2 = torch.autograd.grad(grad.sum(), x)[0]
  File "/home/navlab/anaconda3/envs/nemo/lib/python3.8/site-packages/torch/autograd/__init__.py", line 394, in grad
    result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/navlab/anaconda3/envs/nemo/lib/python3.8/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
  File "/home/navlab/anaconda3/envs/nemo/lib/python3.8/site-packages/tinycudann/modules.py", line 145, in backward
    doutput_grad, params_grad, input_grad = ctx.ctx_fwd.native_tcnn_module.bwd_bwd_input(
RuntimeError: DifferentiableObject::backward_backward_input_impl: not implemented error
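
In the meantime, the workaround I am considering is to approximate the spatial gradient with central finite differences, so the smoothing loss only needs a single backward pass through tcnn (no bwd_bwd_input). A rough sketch, where nf and x are the module and batch from the example above and eps is an arbitrary step size:

def finite_difference_grad(nf, x, eps=1e-4):
    # Central differences along each input dimension; avoids double backward
    # through tcnn at the cost of 2 * n_input_dims extra forward passes.
    grads = []
    for i in range(x.shape[1]):
        offset = torch.zeros_like(x)
        offset[:, i] = eps
        grads.append((nf(x + offset) - nf(x - offset)) / (2 * eps))
    return torch.cat(grads, dim=1)  # (N, n_input_dims)

smooth_loss = finite_difference_grad(nf, x).pow(2).sum(dim=1).mean()
smooth_loss.backward()  # only first-order autograd through tcnn.Network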

Are there any plans to implement double backward for tcnn.Network? It would be extremely helpful for my project.