pytorch / extension-cpp

C++ extensions in PyTorch
1.02k stars 214 forks source link

cuda extension gradient does not match autograd results #60

Closed CoinCheung closed 4 years ago

CoinCheung commented 4 years ago

Hi,

I am not sure whether this is a bug or problem with my implementation. I am working on ubuntu1604 docker container, and ananconda python3.6.9 and pytorch1.3.1. Since I am not sure where the problem exactly exists, my sample code is a bit long, though I have tried my best to remove useless logic. The code is a softmax cross entropy loss, like this:

mport fun_cpp
class CrossEntropyFunctionV2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits, labels, ignore_index):
        losses = fun_cpp.lsr_forward(logits, labels, ignore_index)

        ctx.variables = logits, labels, ignore_index
        return losses

    @staticmethod
    def backward(ctx, grad_output):
        logits, labels, ignore_index = ctx.variables

        grad = fun_cpp.lsr_backward(grad_output, logits, labels, ignore_index)
        print('grad2: ', grad[0, :, 0, 0])
        return grad, None, None

class CrossEntropyLossV2(nn.Module):

    def __init__(self, reduction='mean', ignore_index=-100):
        super(CrossEntropyLossV2, self).__init__()
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, logits, labels):
        losses = CrossEntropyFunctionV2.apply(
                logits, labels, self.ignore_index)
        #  losses = losses[labels != self.ignore_index]
        if self.reduction == 'sum':
            losses = losses.sum()
        elif self.reduction == 'mean':
            losses = losses.mean()
        return losses

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        net = torchvision.models.resnet18(pretrained=False)
        self.conv1 = net.conv1
        self.bn1 = net.bn1
        self.maxpool = net.maxpool
        self.relu = net.relu
        self.layer1 = net.layer1
        self.layer2 = net.layer2
        self.layer3 = net.layer3
        self.layer4 = net.layer4
        self.out = nn.Conv2d(512, 3, 3, 1, 1)
    def forward(self, x):
        feat1 = self.conv1(x)
        feat2 = self.bn1(feat1)
        feat3 = self.relu(feat2)
        #  feat4 = self.maxpool(feat3)
        feat5 = self.layer1(feat3)
        feat6 = self.layer2(feat5)
        feat7 = self.layer3(feat6)
        feat8 = self.layer4(feat7)
        feat9 = self.out(feat8)

        feat7.retain_grad()
        feat7.register_hook(lambda grad: grad*1000)
        return feat9, feat7

net1 = Model()
net2 = Model()
from copy import deepcopy
net2.load_state_dict(deepcopy(net1.state_dict()))

#  criteria1 = CrossEntropyLossV1(reduction='mean', ignore_index=255)
criteria1 = CrossEntropyLossV2(reduction='mean', ignore_index=255)
criteria2 = nn.CrossEntropyLoss(reduction='mean', ignore_index=255)

net1.cuda()
net2.cuda()
net1.train()
net2.train()
criteria1.cuda()
criteria2.cuda()

optim1 = torch.optim.SGD(net1.parameters(), lr=1e-2)
optim2 = torch.optim.SGD(net2.parameters(), lr=1e-2)

bs = 32
for it in range(10):
    inten = torch.randn(bs, 3, 256, 256).cuda()
    lbs = torch.randint(0, 3, (bs, 16, 16)).cuda()
    #  net2.load_state_dict(deepcopy(net1.state_dict()))

    optim1.zero_grad()
    logits, feat = net1(inten.clone())
    loss1 = criteria1(logits, lbs.clone())
    loss1.backward()
    print('feat.grad1', feat.grad[0, :4, 0, 0])
    optim1.step()

    logits, feat = net2(inten.clone())
    optim2.zero_grad()
    loss2 = criteria2(logits, lbs.clone())
    loss2.backward()
    print('feat.grad2', feat.grad[0, :4, 0, 0])
    optim2.step()
    print(loss1.item() - loss2.item())
    print()

The cuda code is like this:

#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <THC/THC.h>
#include <THC/THCAtomics.cuh>
#include <THC/THCDeviceUtils.cuh>

#include <cuda.h>
#include <cuda_runtime.h>
#include <cfloat>

#include <iostream>

using std::cout;
using std::endl;

#define BLOCKSIZE 512

// kernel function for forward and backward
template<typename scalar_t>
__global__ void LSRLossForward(const int n_size,
                            const int dimsize, const int m_size,
                            const scalar_t *log_scores,
                            const int64_t *labels,
                            scalar_t *losses,
                            const int64_t ignore_index) {
    // shared memory
    __shared__ scalar_t sdata[BLOCKSIZE + 2];

    int tid = threadIdx.x;
    int bid = blockIdx.x;

    int samplesize = n_size * m_size;
    for (int i{bid}; i < samplesize; i+=gridDim.x) {
        int n_idx = i / m_size;
        int m_idx = i % m_size;
        int64_t lb = labels[i];
        if (lb == ignore_index) {
            if (tid == 0) losses[i] = 0;
            continue;
        } 

        int idx = n_idx * dimsize * m_size + lb * m_size + m_idx;
        if (tid == 0) losses[i] = -log_scores[idx];
    }
}

template<typename scalar_t>
__global__ void LSRLossBackward(const int n_size,
                            const int dimsize, const int m_size,
                            const scalar_t *grad,
                            scalar_t *grad_logits,
                            const scalar_t *scores,
                            const int64_t *labels,
                            const int64_t ignore_index) {
    int tid = threadIdx.x;
    int bid = blockIdx.x;

    int samplesize = n_size * m_size;
    for (int i{bid}; i < samplesize; i+=gridDim.x) {
        int n_idx = i / m_size;
        int m_idx = i % m_size;
        int64_t lb{labels[i]};
        for (int j{tid}; j < dimsize; j+=blockDim.x) {
            int idx = n_idx * dimsize * m_size + j * m_size + m_idx; 
            scalar_t gradval = 0; 
            if (lb != ignore_index) {
                gradval = scores[idx];
                if (j == lb) {
                    gradval -= 1.;
                }
            }
            grad_logits[idx] = gradval * grad[i];
        }
    }
}

// cuda forward and backward
at::Tensor LSR_forward_cuda(const at::Tensor &logits,
                                  const at::Tensor &labels,
                                  const int64_t ignore_index) {
    // CHECK type and shape
    AT_ASSERTM(logits.type().is_cuda(), "logits should be cuda");
    AT_ASSERTM(labels.type().is_cuda(), "labels should be cuda");

    const int n_size = logits.size(0);
    const int dimsize = logits.size(1);
    const int m_size = logits.numel() / (n_size * dimsize);
    const int samplesize = labels.numel();

    // allocate memory and cuda grid/block
    auto losses = torch::zeros_like(labels, logits.options());
    auto log_scores = torch::log_softmax(logits, 1);

    dim3 grid1(std::min(samplesize, (int)4096));
    dim3 block1(std::min(dimsize, (int)BLOCKSIZE));
    if (losses.numel() == 0) {
        THCudaCheck(cudaGetLastError());
        return losses;
    }

    // call kernel
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(losses.scalar_type(), "lsr forward", [&] {
        int shm_size = BLOCKSIZE * sizeof(scalar_t) * 2; 
        LSRLossForward<scalar_t><<<grid1, block1, shm_size, at::cuda::getCurrentCUDAStream()>>>(
            n_size, dimsize, m_size, 
            log_scores.contiguous().data<scalar_t>(), 
            labels.contiguous().data<int64_t>(), 
            losses.contiguous().data<scalar_t>(),
            ignore_index
        );
    });
    THCudaCheck(cudaGetLastError());
    return losses;
}

at::Tensor LSR_backward_cuda(const at::Tensor &grad,
                                  const at::Tensor &logits,
                                  const at::Tensor &labels,
                                  const int64_t ignore_index) {
    // CHECK type and shape
    AT_ASSERTM(grad.type().is_cuda(), "grad should be cuda");
    AT_ASSERTM(logits.type().is_cuda(), "logits should be cuda");
    AT_ASSERTM(labels.type().is_cuda(), "labels should be cuda");

    const int n_size = logits.size(0);
    const int dimsize = logits.size(1);
    const int m_size = logits.numel() / (n_size * dimsize);
    const int samplesize = labels.numel();

    // allocate memory and cuda grid/block
    auto grad_logits = torch::empty_like(logits);
    auto scores = torch::softmax(logits, 1);

    dim3 grid(std::min(samplesize, (int)4096));
    dim3 block(std::min(dimsize, (int)BLOCKSIZE));
    if (grad_logits.numel() == 0) {
        THCudaCheck(cudaGetLastError());
        return grad_logits;
    }

    // call kernel
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_logits.scalar_type(), "lsr backwrd", [&] {
        int shm_size = BLOCKSIZE * sizeof(scalar_t) * 2; 
        LSRLossBackward<scalar_t><<<grid, block, shm_size, at::cuda::getCurrentCUDAStream()>>>(
            n_size, dimsize, m_size, 
            grad.contiguous().data<scalar_t>(), 
            grad_logits.contiguous().data<scalar_t>(),
            scores.contiguous().data<scalar_t>(), 
            labels.contiguous().data<int64_t>(), 
            ignore_index
        );
    });
    THCudaCheck(cudaGetLastError());
    return grad_logits;
}

// python inferface
at::Tensor LSR_forward(const at::Tensor &logits,
                             const at::Tensor &labels,
                             const int64_t ignore_index
                             ) {
    if (!(logits.type().is_cuda() && labels.type().is_cuda())) {
        AT_ERROR("this LSR loss only supports gpu mode\n");
    } 
    at::DeviceGuard guard(logits.device());
    return LSR_forward_cuda(logits, labels, ignore_index);
}

at::Tensor LSR_backward(const at::Tensor &grad,
                                  const at::Tensor &logits,
                                  const at::Tensor &labels,
                                  const int64_t ignore_index) {
    // TODO: try AT_ASSERTM
    if (!(logits.type().is_cuda() && labels.type().is_cuda())) {
        AT_ERROR("this LSR loss only supports gpu mode\n");
    } 
    at::DeviceGuard guard(logits.device());
    return LSR_backward_cuda(grad, logits, labels, ignore_index);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("lsr_forward", &LSR_forward, "lsr forward");
    m.def("lsr_backward", &LSR_backward, "lsr backward");
}

My problem is that, the difference of gradient and loss between the cuda extension implementation and the nn.CrossEntropyLoss gets too big after about 10 iters. Though the gradient of the cuda implementation of the first iter is same with nn.CrossEntropyLoss. How could I solve this problem please?

ClementPinard commented 4 years ago

Are you sure about your kernel function LSRLossForward ? It seems that sdata is never used, and nothing is done if tid is not 0. Not sure this is the root of your problem, but it's worth a try

CoinCheung commented 4 years ago

Yes, my original logic needs shared memory so I left it there unused, and so it is with the tid problem. However, these two places seem to make code and logic simpler, and after removing them the problem still seems to exist.

CoinCheung commented 4 years ago

I solved the problem, it is some problem with my implementation. I am closing this. Thanks for support !!!