bfGraph / STGraph

🌟 Vertex Centric approach for building GNN/TGNNs
MIT License

cuModule Load 700 error #42

Closed: JoelMathewC closed this issue 1 year ago

JoelMathewC commented 1 year ago

When running the code on Google Colab or on an external GPU server, we run into a cuModuleLoad error. This issue tracks the fix. The current status of the dynamic models is as follows:

  1. Naive and PCSR do not work
  2. GPMA does not work

By commenting out lines in the template, we have narrowed the problem down to this region:

// Within the inner for loop
{%for agg_stmt in aggs%}
{{agg_stmt.compute}}
{{agg_stmt.inner_write}}
{%endfor%}

In the generated kernel for TGCN, this expands to the following statement, where both variables are floats:

V12_tmp += V11_tmp;
JoelMathewC commented 1 year ago

Referring to the cuModuleLoad documentation:

The CUDA driver API does not attempt to lazily allocate the resources needed by a module; 
if the memory for functions and data (constant and global) needed by the module cannot be allocated, 
cuModuleLoad() fails.

So the module may be failing to allocate the resources it needs. Two things remain unclear:

  1. Why it works for GPMA.
  2. Why commenting out a statement that has nothing to do with allocation solves the problem.
JoelMathewC commented 1 year ago

There is a discussion about possible hurdles with compilation and 64-bit addressing. While that seems close to what we are looking for, it isn't quite it, because it still does not explain why commenting out that one float line solves the problem.

JoelMathewC commented 1 year ago

I tried loading the module on its own and it seems to work, which makes this even more odd. The code is given below:

from ctypes import CDLL, byref, c_char_p, c_int, c_void_p

PTX_PATH = "egl_kernel.ptx"
cuda = CDLL('libcuda.so')
cuda.cuInit(0)

# Create a context on device 0 (cuModuleLoad needs a current context)
device = c_int()
cuda.cuDeviceGet(byref(device), 0)
cu_context = c_void_p()
cuda.cuCtxCreate(byref(cu_context), 0, device)

# Load the generated PTX into the current context
cu_module = c_void_p()
char_p = c_char_p(PTX_PATH.encode())
ret = cuda.cuModuleLoad(byref(cu_module), char_p)
if ret:
    raise Exception('cuModuleLoad fails with ret ' + str(ret))

# ret is 0

Additionally, I added the context-creation step and it still doesn't work. This raises a general question: where exactly are cuInit and cuCtxCreate called in Seastar?

After looking into it, cuInit is called in the deviceinfo function. However, there doesn't seem to be a call to cuCtxCreate anywhere.

Update

I tested commenting out the cuInit call in deviceinfo, and it had no impact on the result (the CUresult is still 700). This implies that cuInit and cuCtxCreate are happening somewhere else in Seastar.

After going through the PyTorch docs, I realized that PyTorch handles this internally, which solves that mystery. This can be verified by commenting out cuInit and cuCtxCreate and running the following code before invoking cuModuleLoad:

import torch
a = torch.tensor([1,2], device="cuda:0")

However, there is still the question of why Seastar fails at cuModuleLoad.

JoelMathewC commented 1 year ago

Starting over again

The following code is capable of loading the module correctly.

from ctypes import CDLL, byref, c_char_p, c_void_p

# Trigger torch's cuInit and cuCtxCreate
import torch
a = torch.tensor([1,2], device="cuda:0")

# cuModuleLoad
PTX_PATH = "egl_kernel.ptx"
cuda = CDLL('libcuda.so')

# Same deviceinfo() as in Seastar, with the cuInit commented out
device = deviceinfo()

cu_module = c_void_p()
char_p = c_char_p(PTX_PATH.encode())
ret = cuda.cuModuleLoad(byref(cu_module), char_p)
if ret:
    raise Exception('cuModuleLoad fails with ret ' + str(ret))

Seastar follows a similar approach but fails to load the module.

JoelMathewC commented 1 year ago

This is a bit of a shot in the dark, but why are the kernel variables numbered up to V22 when there are actually only 3-4 variables?

JoelMathewC commented 1 year ago

Another important point: a TGCN needs 3 GCNs. The Seastar module for the first GCN loads successfully, but the module for the second GCN fails to load, and I can't understand why.

A start

Since one cuModuleLoad works, I think it's safe to say that the issue is not with the function or the kernel; there is probably some other, external issue. One possibility is that some threads fail when the first kernel executes, and that error is only reported when the second load is called. We can verify this by changing the TGCN to contain only a single GCN and checking whether that works.

Tested it out and we have hit an error 🥳🥳. We are getting CUresult 700, which per the CUDA manual is CUDA_ERROR_ILLEGAL_ADDRESS. Using a few print statements, we isolated that this error is triggered by the first cuLaunchKernel itself.
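Raw return codes like 700 are easy to misread, so a small lookup can make the prints more informative. This is a minimal sketch with a hand-copied subset of the CUresult enum from cuda.h (on a GPU machine the driver's own cuGetErrorName does the same lookup); describe_cuda_error and check are hypothetical helper names. It also helps to remember that CUDA_ERROR_ILLEGAL_ADDRESS is sticky and kernel launches are asynchronous: once a kernel hits it, later driver calls on the same context also return 700, which fits the error first showing up at the second cuModuleLoad.

```python
# Hand-copied subset of the CUresult enum from cuda.h (not exhaustive).
CU_RESULT_NAMES = {
    0: "CUDA_SUCCESS",
    1: "CUDA_ERROR_INVALID_VALUE",
    2: "CUDA_ERROR_OUT_OF_MEMORY",
    3: "CUDA_ERROR_NOT_INITIALIZED",
    200: "CUDA_ERROR_INVALID_IMAGE",
    201: "CUDA_ERROR_INVALID_CONTEXT",
    218: "CUDA_ERROR_INVALID_PTX",
    700: "CUDA_ERROR_ILLEGAL_ADDRESS",
    719: "CUDA_ERROR_LAUNCH_FAILED",
}


def describe_cuda_error(ret: int) -> str:
    """Return a readable name for a raw CUresult code (hypothetical helper)."""
    name = CU_RESULT_NAMES.get(int(ret), "UNKNOWN_CUresult")
    return f"{name} ({int(ret)})"


def check(ret: int, call: str) -> None:
    """Raise with a readable message if a driver call did not return CUDA_SUCCESS."""
    if int(ret) != 0:
        raise RuntimeError(f"{call} failed: {describe_cuda_error(ret)}")


print(describe_cuda_error(700))  # CUDA_ERROR_ILLEGAL_ADDRESS (700)
```

With the ctypes code above, a call such as `check(cuda.cuModuleLoad(byref(cu_module), char_p), "cuModuleLoad")` would replace the bare `if ret:` test.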

Finding 1: The error actually comes from cuLaunchKernel. This means, as Unnikrishnan Sir and Jude suggested, that it could be caused by invalid memory.

However, it is still odd that everything works when the float addition is commented out. One more test is to see how commenting it out affects cuLaunchKernel.

It turns out the launch succeeds in that case.

Finding 2: The error has something to do with the float aggregation during cuLaunchKernel.

JoelMathewC commented 1 year ago

CUDA Python Migration

To make debugging easier, we have moved to CUDA Python. The same 700 error is reproduced in the CUDA Python version.
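In CUDA Python (cuda-python), each driver call returns a tuple whose first element is the CUresult and whose remaining elements are the outputs, so every call needs the same unpack-and-check step. Below is a minimal sketch of such a helper; the name unwrap is hypothetical, and it assumes the CUresult value converts with int() (it is an integer enum in the versions we are aware of), with 0 meaning CUDA_SUCCESS.

```python
from typing import Any, Tuple


def unwrap(result: Tuple[Any, ...], call: str = "driver call") -> Any:
    """Split a cuda-python style (CUresult, *outputs) tuple, raising on failure."""
    err, *outputs = result
    if int(err) != 0:
        raise RuntimeError(f"{call} failed with CUresult {int(err)}")
    if not outputs:
        return None                 # e.g. cuInit returns only (err,)
    if len(outputs) == 1:
        return outputs[0]           # e.g. cuModuleLoad returns (err, module)
    return tuple(outputs)

# Possible usage on a GPU machine (not run here; the import path depends on
# the cuda-python version):
# from cuda import cuda
# unwrap(cuda.cuInit(0), "cuInit")
# module = unwrap(cuda.cuModuleLoad(b"egl_kernel.ptx"), "cuModuleLoad")
```

Checking every call this way makes it easier to see which call first reports the 700, rather than discovering it several calls later.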

JoelMathewC commented 1 year ago

I might have figured out why commenting out the float addition allows the kernel to run. I believe it has to do with Dead Code Elimination.

extern "C" __global__ void K0
(float *Vhinb, float *Vnormcen, float *Vnorminb, float *Vweight, float *V3, 
  unsigned int *row_offsets,
  unsigned int *eids,
  unsigned int *column_indices,
  int num_nodes,
  int max_dimx,
  int max_dimy,
  int thrs_per_group,
  int nodes_per_block) {

    int dst_id = nodes_per_block*blockIdx.x + threadIdx.x/thrs_per_group;

    if (dst_id < num_nodes) {

        int feat_len = max_dimx * max_dimy;
        int beg = __ldg(row_offsets + dst_id);
        int end = __ldg(row_offsets + dst_id + 1);
        int tx = threadIdx.x % thrs_per_group;

        for (; tx<feat_len; tx+=blockDim.x) {

            float V2_tmp = 0;
            int offset3 = dst_id * 32 + tx;int offset4 = dst_id * 1 + tx/32;

            for (int e=beg;e<end;++e) {

                int src_id = __ldg(column_indices + e);
                int eid = __ldg(eids + e);

                int offset0 = src_id * 32 + tx;int offset1 = src_id * 1 + tx/32;int offset2 = eid * 1 + tx/32;

                float V0_tmp = Vnorminb[offset1]*Vhinb[offset0];

                float V1_tmp = V0_tmp*Vweight[offset2];

               // This is what needs to be commented out
                V2_tmp += V1_tmp;

            }

            float V3_tmp = V2_tmp*Vnormcen[offset4];
            V3[offset3] = V3_tmp;

        }
    }
}

If the marked line is commented out, all of the other code in that for loop becomes dead code, and nvcc prunes dead code (as discussed on StackOverflow). However, this is just a theory.
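To make the dead-code theory concrete, here is a toy backward-liveness pass over a simplified model of K0's body, written as (defined value, values read) pairs. It is an illustration only, not how nvcc/NVVM is implemented: loads are treated as side-effect free (which `__ldg` read-only loads are), and the store to V3 is the only observable effect. The names eliminate_dead_code and K0_BODY are hypothetical. Without `V2_tmp += V1_tmp`, nothing computed inside the edge loop reaches the store, so the column_indices, eids and feature loads all disappear.

```python
from typing import Iterable, List, Sequence, Tuple

Stmt = Tuple[str, Sequence[str]]  # (value defined, values it reads)


def eliminate_dead_code(stmts: List[Stmt], live_out: Iterable[str]) -> List[Stmt]:
    """Toy backward liveness pass: keep a statement only if its result is needed later."""
    live = set(live_out)
    kept = []
    for target, operands in reversed(stmts):
        if target in live:
            kept.append((target, operands))
            live.discard(target)
            live.update(operands)
        # otherwise the statement (including its loads) is dropped as dead code
    return list(reversed(kept))


# One pass suffices for this model: V2_tmp is the only value carried across
# loop iterations, and every other loop value is defined before it is used.
K0_BODY: List[Stmt] = [
    ("V2_tmp", []),                                   # float V2_tmp = 0;
    ("offset3", ["dst_id", "tx"]),
    ("offset4", ["dst_id", "tx"]),
    ("src_id", ["column_indices", "e"]),              # inner loop starts
    ("eid", ["eids", "e"]),
    ("offset0", ["src_id", "tx"]),
    ("offset1", ["src_id", "tx"]),
    ("offset2", ["eid", "tx"]),
    ("V0_tmp", ["Vnorminb", "offset1", "Vhinb", "offset0"]),
    ("V1_tmp", ["V0_tmp", "Vweight", "offset2"]),
    ("V2_tmp", ["V2_tmp", "V1_tmp"]),                 # V2_tmp += V1_tmp;
    ("V3_tmp", ["V2_tmp", "Vnormcen", "offset4"]),    # after the loop
    ("V3[]", ["V3_tmp", "offset3"]),                  # V3[offset3] = V3_tmp;
]

ACCUMULATE = 10  # index of the V2_tmp += V1_tmp statement
without_accumulate = K0_BODY[:ACCUMULATE] + K0_BODY[ACCUMULATE + 1:]
print([t for t, _ in eliminate_dead_code(without_accumulate, ["V3[]"])])
# ['V2_tmp', 'offset3', 'offset4', 'V3_tmp', 'V3[]']
```

With the accumulation left in, the same pass keeps every statement, so the faulty loads are only executed in that version.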

Verification

This can be verified by disabling optimizations. The flags are:

-g -G -Xcompiler -O0 -Xptxas -O0 -lineinfo -O0

I've verified it and the theory is right: commenting out the line only appeared to fix the problem because the optimizer then removed the rest of the loop body. That's another mystery solved.

JoelMathewC commented 1 year ago

Moving Forward

I started looking into why exactly there is an invalid memory error. I think the problem is with the eids. I verified this by changing the kernel as follows, which hit the same error:

// Only showing the relevant parts
for (int e=beg;e<end;++e) {
    int src_id = __ldg(column_indices + e);
    int eid = __ldg(eids + e);

    int offset0 = src_id * 1 + tx/32;int offset1 = src_id * 32 + tx;int offset2 = eid * 1 + tx/32;
    float V0_tmp = Vnorminb[offset0]*Vhinb[offset1];
    float V1_tmp = V0_tmp*Vweight[offset2];

    V2_tmp = eid;
}

However, after reviewing the Naive code, I don't see an error in the approach. I also verified the pointer locations, and they seem to check out. I think the fact that it works on a local laptop but not on a GPU server is key to solving this issue; I just don't know where that missing piece fits yet.

We can try to sort this out by using the pointers to print the values at those locations from Python, as a verification step.

JoelMathewC commented 1 year ago

I did a small test by changing the statement to V2_tmp = beg; and it works, so reading row_offsets does not trigger the error.

Changing the statement to V2_tmp = __ldg(column_indices + e); does trigger it, and reading eids does too.

Making the following modification lets the module run successfully:

...
for (int e=beg;e<end;++e) {
    int dst_id = __ldg(column_indices + e);
    int eid = __ldg(eids + e);
    int offset0 = dst_id * 1 + tx/32;int offset1 = dst_id * 32 + tx;int offset2 = eid * 1 + tx/32;
    float V5_tmp = V4[offset1]*Vnormcen[offset0];
    float V7_tmp = V5_tmp*Vweight[offset2];

    if (e < 127){
        V8_tmp = __ldg(column_indices + e);
    }
}
...
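Since the failure depends on how far e runs through column_indices and eids, a cheap host-side sanity check of the CSR arrays can rule out bad graph data before anything reaches the GPU. This is a minimal sketch, assuming the arrays are plain Python sequences, row_offsets has num_nodes + 1 entries, and eids are expected to lie in 0..num_edges-1 (so Vweight[eid] stays in bounds); validate_csr is a hypothetical name.

```python
from typing import List, Sequence


def validate_csr(row_offsets: Sequence[int], column_indices: Sequence[int],
                 eids: Sequence[int], num_nodes: int) -> List[str]:
    """Return the problems found in a CSR graph; an empty list means it looks sane."""
    problems: List[str] = []
    if len(row_offsets) != num_nodes + 1:
        problems.append(f"row_offsets has {len(row_offsets)} entries, expected {num_nodes + 1}")
        return problems
    if row_offsets[0] != 0:
        problems.append("row_offsets[0] must be 0")
    for i in range(num_nodes):
        if row_offsets[i] > row_offsets[i + 1]:
            problems.append(f"row_offsets decreases at node {i}")
    num_edges = row_offsets[-1]
    if len(column_indices) != num_edges or len(eids) != num_edges:
        problems.append(f"expected {num_edges} edges, got {len(column_indices)} "
                        f"column indices and {len(eids)} eids")
    for e, v in enumerate(column_indices):
        if not 0 <= v < num_nodes:          # would make __ldg/feature reads go out of bounds
            problems.append(f"column_indices[{e}] = {v} is not a valid node id")
    for e, eid in enumerate(eids):
        if not 0 <= eid < num_edges:        # would index past the end of Vweight
            problems.append(f"eids[{e}] = {eid} is out of range")
    return problems
```

For a 3-node graph, `validate_csr([0, 2, 3, 4], [1, 2, 0, 1], [0, 1, 2, 3], 3)` returns an empty list, while changing one column index to 5 produces a single "not a valid node id" entry.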
JoelMathewC commented 1 year ago

Serious issue resolved

I identified a serious issue: in NaiveGraph, the backward edge list was created from the forward edge list, but because no deep copy was made, sorting it also modified the forward edge list, which broke the correctness of the CSR computation. This was corrected with the following change:

import copy

def _prepare_edge_lst_bwd(self, edge_list):
    self.bwd_edge_list = []
    for i in range(len(edge_list)):
        # change: copy before sorting so the forward edge list is not modified in place
        edge_list_for_t = copy.deepcopy(edge_list[i])
        edge_list_for_t.sort()
        self.bwd_edge_list.append(edge_list_for_t)
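The aliasing bug is easy to reproduce in isolation. The sketch below uses hypothetical function names and a made-up edge list of (src, dst) tuples; the buggy version sorts the caller's list in place and stores the very same list object, while the fixed version mirrors the deepcopy change above.

```python
import copy


def prepare_bwd_buggy(edge_list):
    bwd = []
    for edges_t in edge_list:
        edges_t.sort()            # sorts the caller's (forward) list in place
        bwd.append(edges_t)       # and shares the same list object
    return bwd


def prepare_bwd_fixed(edge_list):
    bwd = []
    for edges_t in edge_list:
        edges_copy = copy.deepcopy(edges_t)  # independent copy, as in the fix
        edges_copy.sort()
        bwd.append(edges_copy)
    return bwd


forward = [[(2, 0), (0, 1), (1, 2)]]
backward = prepare_bwd_fixed(forward)
print(forward)   # [[(2, 0), (0, 1), (1, 2)]]  (unchanged)
print(backward)  # [[(0, 1), (1, 2), (2, 0)]]
```

A shallow copy such as `sorted(edge_list[i])` would already stop the reordering; deepcopy additionally protects the individual edges if later code modifies them in place.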
JoelMathewC commented 1 year ago

Last Call

The only remaining lead is that the pointers may not be valid. To check this, we will pass the pointer from Python to the C++ module, convert the address to a thrust device pointer, copy the contents to the host, and print them. Since the code runs at the moment, the memory does seem to remain allocated rather than freed, but we have no proof that the contents of row_offset, col_indices, or eids are valid. This sanity check should settle that, and I might also learn a thing or two about PyBind11.

I switched to uintptr_t to represent the pointers (following a reference on casting), then created the following function in csr.cu:

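#include <cstdint>
#include <iostream>
#include <vector>
#include <cuda_runtime.h>

// Copy `size` ints from a device address (passed from Python as an integer) to the host and print them
void print_dev_array(std::uintptr_t ptr, int size)
{
    const int *dev_ptr = reinterpret_cast<const int *>(ptr);
    std::vector<int> host(size);  // host buffer, released automatically (the malloc version leaked)
    cudaError_t err = cudaMemcpy(host.data(), dev_ptr, size * sizeof(int), cudaMemcpyDeviceToHost);
    if (err != cudaSuccess)
    {
        std::cerr << "cudaMemcpy failed: " << cudaGetErrorString(err) << "\n";
        return;
    }
    std::cout << "\nPRINTING ARRAY\n";
    for (int i = 0; i < size; ++i)
    {
        std::cout << host[i] << " ";
    }
    std::cout << "\n";
}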

Printing row_offset with the function above after the CSR is created gives the following:

191889408 32528 191889408 32528 -50341456 22078 -50341456 22078 0 0 0 0 -54015632 22078 0 0 2097152 0 
-54015632 22078 0 0 0 0 0 0 0 0 0 0 25574 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 -50341280 22078 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 

The expected result is

0 40 53 65 90 107 130 150 152 173 189 216 226 232 247 261 274 290 312 319 333 352 361 396 420 423 427 445 468 495
 516 527 541 561 572 591 625 668 737 750 753 774 804 818 830 845 860 870 901 918 931 943 957 970 979 993 1009 
1035 1050 1059 1072 1080 1098 1117 1130 1141 1158 1178 1193 1209 1225 1231 1241 1255 1265 1278 1288 1304 1318 1330 
1355 1375 1389 1434 1453 1479 1493 1504 1517 1550 1554 1572 1592 1618 1637 1648 1666 1690 1700 1722 1757 1771 
1777 1785 1809 1831 1841 1851 1857 1873 1881 1889 1904 1944 1964 1972 1981 2002 2014 2017 2023 2046 2072 2082 
2091 2108 2129 2130 2138 2158 

After testing: the pointer that is passed to Python and then back to the CSR C++ file does not point at the expected data. We will have to look at why that is the case.
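One way to read the garbage dump: if each pair of printed 32-bit ints is taken as the low and high halves of a 64-bit word (assuming a little-endian layout), the repeated pairs look like memory addresses rather than graph data, which would fit memory that has been freed and reused. This interpretation is a guess; the sketch below (pairs_to_u64 is a hypothetical name) only does the arithmetic on the first values printed above.

```python
def pairs_to_u64(words):
    """Combine consecutive signed 32-bit ints (low half first) into unsigned 64-bit values."""
    if len(words) % 2:
        raise ValueError("need an even number of 32-bit words")
    values = []
    for lo, hi in zip(words[0::2], words[1::2]):
        values.append(((hi & 0xFFFFFFFF) << 32) | (lo & 0xFFFFFFFF))
    return values


dump = [191889408, 32528, -50341456, 22078]  # first values in the dump above
print([hex(v) for v in pairs_to_u64(dump)])  # ['0x7f100b700000', '0x563efcffd9b0']
```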

JoelMathewC commented 1 year ago

Finally something

I modified the CSR constructor so that, after the device pointers are created in get_csr_ptrs, it calls another function that copies the contents from device to host and prints them. The output is garbage, and these garbage values exactly match the output seen after the pointer is passed through Python.

Inference: PyBind11 is not corrupting the pointers; instead, the device vectors are somehow getting freed.

A Fix

I moved the following definitions into the CSR Object and it worked 🥳🥳🥳

DEV_VEC row_offset_device;
DEV_VEC column_indices_device;
DEV_VEC eids_device;

It turns out that because each device_vector was a local variable, it went out of scope when the function returned, its destructor ran, and the device memory was freed. A similar explanation is given in the edit section of the accepted answer to a related question.

Additionally, we can switch from size_t to uintptr_t for passing pointers.
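The same lifetime mistake can be modelled in Python. In the sketch below, DeviceVector is a hypothetical stand-in for thrust::device_vector whose `__del__` plays the role of the destructor freeing device memory, and id() stands in for the raw address handed to Python. It relies on CPython's reference counting, which runs `__del__` as soon as the last reference disappears.

```python
freed = []


class DeviceVector:
    """Hypothetical stand-in for thrust::device_vector: frees its memory when destroyed."""
    def __init__(self, name):
        self.name = name

    def __del__(self):
        freed.append(self.name)   # models the destructor releasing device memory


class CSRLocalBuffers:
    """Buggy shape: the vector is a local; only its raw address is kept."""
    def __init__(self):
        row_offset_device = DeviceVector("local row_offset")
        self.row_offset_ptr = id(row_offset_device)  # raw handle, does not own the memory
        # row_offset_device is released as soon as __init__ returns


class CSRMemberBuffers:
    """Fixed shape: the vector is a member, so it lives as long as the CSR object."""
    def __init__(self):
        self.row_offset_device = DeviceVector("member row_offset")
        self.row_offset_ptr = id(self.row_offset_device)
```

Constructing CSRLocalBuffers() records a free immediately, leaving its row_offset_ptr dangling, whereas CSRMemberBuffers only frees its vector when the CSR object itself is deleted, which is the behaviour the fix restores.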

JoelMathewC commented 1 year ago

IT WORKS!

I ran one forward propagation and no longer get an illegal memory access error, so that closes this issue.

Note that there is still an error preventing the code from running to completion, but a separate issue will be created for that if required.