NVlabs / tiny-cuda-nn

Lightning fast C++/CUDA neural network framework

Hash encoding second order gradients error "RuntimeError: ... cudaEventRecord(m_event, stream) failed with error an illegal memory access was encountered" #301

Open caseypeat opened 1 year ago

caseypeat commented 1 year ago

Hi,

When working with second-order gradients for SDF-based NeRF work, I get an illegal memory access error (which can present in a few ways, but here's an example):

RuntimeError: ... cudaEventRecord(m_event, stream) failed with error an illegal memory access was encountered

Below I've added a minimal example that demonstrates the issue.

A few things I've noticed that may or may not be useful:

My setup is as follows:

Hopefully this contains enough to reproduce the issue, but feel free to let me know if you'd like any more info, or if I've missed something.

import torch
import torch.nn as nn
import torch.nn.functional as F

import tinycudann as tcnn

class MLP(nn.Module):
    def __init__(self, input_size, hidden_sizes, output_size):
        super(MLP, self).__init__()
        self.layers = nn.ModuleList()

        # Add input layer
        self.layers.append(nn.Linear(input_size, hidden_sizes[0]))
        self.layers.append(nn.ReLU())

        # Add hidden layers
        for i in range(len(hidden_sizes) - 1):
            self.layers.append(nn.Linear(hidden_sizes[i], hidden_sizes[i+1]))
            self.layers.append(nn.ReLU())

        # Add output layer
        self.layers.append(nn.Linear(hidden_sizes[-1], output_size))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

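    # First-order gradient of the output w.r.t. the input. create_graph=True keeps
    # the graph alive so this gradient can itself be backpropagated (second order).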
    def gradient(self, x):
        x.requires_grad_(True)
        y = self.forward(x)
        d_output = torch.ones_like(y, requires_grad=False, device=y.device)
        gradients = torch.autograd.grad(
            outputs=y,
            inputs=x,
            grad_outputs=d_output,
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        return gradients

class NGP(nn.Module):
    def __init__(self):
        super(NGP, self).__init__()
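        # HashGrid encoding; dtype=torch.float32 requests full-precision output
        # from tiny-cuda-nn.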
        self.encoder = tcnn.Encoding(
            n_input_dims=3,
            encoding_config={
                "otype": "HashGrid",
                "n_levels": 16,
                "n_features_per_level": 2,
                "log2_hashmap_size": 19,
                "base_resolution": 16,
                "per_level_scale": 1.3819,
            },
            dtype=torch.float32
        )

    def forward(self, x):
        prefix = x.shape
        x = x.reshape(-1, 3)

        x_hashtable = self.encoder(x)

        y = x_hashtable[..., :3]

        y = y.reshape(*prefix)

        return y

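    # Same pattern as MLP.gradient, but the graph now runs through the tcnn HashGrid
    # encoder, so backpropagating this gradient needs second-order derivatives of the
    # encoding, which is where the illegal memory access shows up.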
    def gradient(self, x):
        x.requires_grad_(True)
        y = self.forward(x)
        d_output = torch.ones_like(y, requires_grad=False, device=y.device)
        gradients = torch.autograd.grad(
            outputs=y,
            inputs=x,
            grad_outputs=d_output,
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        return gradients

if __name__ == "__main__":

    model = NGP().to("cuda")
    # model = MLP(3, [64, 64], 3).to("cuda")

    optimizer = torch.optim.Adam([{'name': 'model', 'params': list(model.parameters()), 'lr': 1e-3}])

    scaler = torch.cuda.amp.GradScaler()

    batchsize = 16
    n_samples = 8

    for i in range(1000):

        x = torch.rand([batchsize, n_samples, 3]).to("cuda")
        y = torch.rand([batchsize, n_samples, 3]).to("cuda")

        with torch.cuda.amp.autocast():

            optimizer.zero_grad()

            y_ = model(x)

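            # Input gradient (runs a second forward pass and builds a differentiable graph)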
            grad = model.gradient(x)

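            # Eikonal term: penalizes deviation of the input-gradient norm from 1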
            loss_eik = torch.mean((torch.linalg.norm(grad, ord=2, dim=-1) - 1.0) ** 2)
            loss_rgb = F.l1_loss(y_, y)

            loss = loss_rgb + loss_eik
            # loss = loss_rgb
            # loss = loss_eik

            print(loss)

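            # Backprop through loss_eik differentiates grad a second time; with the NGP
            # model this is the second-order path through the hash encoding. The CUDA
            # error is asynchronous, so the Python call site it surfaces at can vary.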
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update() 
cecabert commented 1 year ago

Running into the same error, any update on this?