ledatelescope / bifrost

A stream processing framework for high-throughput applications.
BSD 3-Clause "New" or "Revised" License

Use after free and other issues with JIT #209

Closed. torrance closed this issue 1 year ago.

torrance commented 1 year ago

There are some issues around copying the device code used in JIT. The copying code assumes that the null character will never appear in the PTX (or ROCm) code.

Case 1:

https://github.com/ledatelescope/bifrost/blob/30fd2682ec22bc26d9dd8cd9bc424aefc4e7075e/src/map.cpp#L386-L398

Here the vector vptx is initialised and filled with the device code. A pointer to its first element is then taken and assigned to the std::string ptx_string, which implicitly treats it as a null-terminated C string (the const char* overload, equivalent to the std::string(const char*) constructor).

This is wrong in three ways:
1. if vptx contains no null terminator, std::string reads past the end of vptx looking for one;
2. it then relies on the uninitialised memory past that address happening to contain a null; and
3. it assumes that the null character can never be part of the device code.
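A minimal C++ sketch of the difference (not Bifrost code; the helper names and buffer contents are made up for illustration): copying a buffer through the const char* overload stops at the first null byte, while the iterator-range constructor used in the patch copies every byte.

```cpp
#include <string>
#include <vector>

// Copies the buffer the way the original code did: through the implicit
// std::string(const char*) conversion. Stops at the first '\0' and, if the
// buffer contains no '\0', reads past the end of the vector (undefined behaviour).
std::string copy_as_c_string(const std::vector<char>& buf) {
    return std::string(buf.data());
}

// Copies exactly buf.size() bytes, including any embedded '\0' characters.
std::string copy_as_range(const std::vector<char>& buf) {
    return std::string(buf.begin(), buf.end());
}
```

For a binary-looking buffer such as {'\x7f', 'E', 'L', 'F', '\0', '\x02', '\x01', '\0'}, copy_as_c_string() returns 4 bytes and copy_as_range() returns all 8.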

Case 2:

https://github.com/ledatelescope/bifrost/blob/30fd2682ec22bc26d9dd8cd9bc424aefc4e7075e/src/cuda.hpp#L137-L148

The calling function passes std::string::c_str(), and that pointer is then copied back into a std::string via the implicit std::string(const char*) constructor. If the null character forms part of the device code, the copy is silently truncated at the first null.
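The same truncation happens on any c_str() round trip. The two structs below are hypothetical stand-ins for CUDAKernel::set() that differ only in parameter type; they are a sketch, not Bifrost's actual class.

```cpp
#include <string>

// Hypothetical stand-in: takes the code as a C string, like the original set().
struct KernelSourceByPointer {
    std::string ptx;
    // Re-wrapping the pointer counts bytes only up to the first '\0'.
    void set(const char* code) { ptx = code; }
};

// Hypothetical stand-in: takes the code as a std::string, like the patched set().
struct KernelSourceByString {
    std::string ptx;
    // Copying the std::string keeps its stored length, embedded '\0' included.
    void set(const std::string& code) { ptx = code; }
};
```

With std::string code("AB\0CD", 5), the pointer version stores 2 bytes and the string version stores all 5. Passing by const reference, as sketched here, also avoids copying the code buffer on each call; the patch passes by value, which is equally correct.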

The following patch remedies the segfaults I've been seeing:

diff --git a/src/cuda.hpp b/src/cuda.hpp
index 27cce52..6281174 100644
--- a/src/cuda.hpp
+++ b/src/cuda.hpp
@@ -143,8 +143,8 @@ public:
            this->create_module();
        }
    }
-   inline CUDAKernel(const char*   func_name,
-                     const char*   ptx,
+   inline CUDAKernel(const std::string   func_name,
+                     const std::string   ptx,
                      unsigned int  nopts=0,
                      hipJitOption* opts=0,
                      void**        optvals=0) {
@@ -153,8 +153,8 @@ public:
        _opts.assign(opts, opts + nopts);
        this->create_module(optvals);
    }
-   inline CUDAKernel& set(const char*   func_name,
-                          const char*   ptx,
+   inline CUDAKernel& set(const std::string   func_name,
+                          const std::string   ptx,
                           unsigned int  nopts=0,
                           hipJitOption* opts=0,
                           void**        optvals=0) {
diff --git a/src/map.cpp b/src/map.cpp
index 001915f..9187ea0 100644
--- a/src/map.cpp
+++ b/src/map.cpp
@@ -389,16 +389,15 @@ BFstatus build_map_kernel(int*                 external_ndim,
    size_t ptxsize;
    BF_CHECK_NVRTC( nvrtcGetPTXSize(program, &ptxsize) );
    std::vector<char> vptx(ptxsize);
-   char* ptx = &vptx[0];
-   BF_CHECK_NVRTC( nvrtcGetPTX(program, &ptx[0]) );
+   BF_CHECK_NVRTC( nvrtcGetPTX(program, vptx.data()) );
    BF_CHECK_NVRTC( nvrtcDestroyProgram(&program) );
 #if BF_DEBUG_ENABLED
    if( EnvVars::get("BF_PRINT_MAP_KERNELS_PTX", "0") != "0" ) {
-       std::cout << ptx << std::endl;
+       std::cout << vptx.data() << std::endl;
    }
 #endif
-   *ptx_string = ptx;
-   *kernel_name_ptr = kernel_name;
+   *ptx_string = {vptx.begin(), vptx.end()};
+   *kernel_name_ptr = {kernel_name.begin(), kernel_name.end()};
    // TODO: Can't do this, because this function may be cached
    //         **Clean this up!
    //*external_ndim = ndim;
@@ -731,7 +730,7 @@ BFstatus bfMap(int                  ndim,
                                      &ptx, &kernel_name));
        }
        CUDAKernel kernel;
-       BF_TRY(kernel.set(kernel_name.c_str(), ptx.c_str()));
+       BF_TRY(kernel.set(kernel_name, ptx));
        kernel_cache.insert(cache_key,
                            std::make_pair(kernel, basic_indexing_only));
 #if defined(BF_MAP_KERNEL_DISK_CACHE) && BF_MAP_KERNEL_DISK_CACHE

benbarsdell commented 1 year ago

I'm not sure what's causing your segfaults, but PTX is text, not binary, so it is null terminated and cannot contain null characters. nvrtcGetPTX[Size] includes the null terminator.

torrance commented 1 year ago

Righto, then it's only a problem if you're hipifying the code, as hiprtcGetCode() returns binary when run on a ROCm backend.
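Taken together, code that serves both backends has to distinguish text from binary. A hedged sketch, assuming a hypothetical helper code_buffer_to_string() (not part of Bifrost) and a caller that knows whether the buffer holds PTX text, whose reported size includes the trailing null, or a binary code object such as hiprtcGetCode() produces on ROCm. Note that the patch's {vptx.begin(), vptx.end()} copy keeps that terminator as the last character of the PTX string; the sketch strips it for text only.

```cpp
#include <string>
#include <vector>

// Hypothetical helper. Assumes buf was filled by a size query followed by a
// fetch into a vector of exactly that size.
// Text (e.g. PTX): drop the single trailing '\0' counted in the reported size.
// Binary (e.g. a ROCm code object): keep every byte, embedded '\0' included.
std::string code_buffer_to_string(const std::vector<char>& buf, bool is_text) {
    auto end = buf.end();
    if (is_text && !buf.empty() && buf.back() == '\0') {
        --end;  // strip only the trailing terminator
    }
    return std::string(buf.begin(), end);
}
```

Because both branches use the iterator-range constructor, neither one depends on a terminator being present in the buffer.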

I will leave the patch as part of the full HIP patch-set then.