Running your 3rd example with cudf and RMM 0.14:
(cudf_dev_10.2) mharris@dgx02:~/rapids$ nvprof python test.py
==17274== NVPROF is profiling process 17274, command: python test.py
Press enter to continue
Press enter to exit
==17274== Profiling application: python test.py
==17274== Profiling result:
No kernels were profiled.
Type Time(%) Time Calls Avg Min Max Name
API calls: 99.30% 2.26454s 2 1.13227s 473.90ms 1.79063s cudaMalloc
0.37% 8.5109ms 8 1.0639ms 1.0488ms 1.0700ms cuDeviceTotalMem
0.27% 6.2377ms 776 8.0380us 113ns 321.28us cuDeviceGetAttribute
0.03% 641.85us 2 320.92us 275.21us 366.64us cudaFree
0.02% 526.16us 8 65.770us 62.850us 72.260us cuDeviceGetName
0.00% 18.121us 8 2.2650us 1.3600us 4.8100us cuDeviceGetPCIBusId
0.00% 17.087us 2 8.5430us 8.3120us 8.7750us cudaStreamSynchronize
0.00% 5.8310us 16 364ns 154ns 959ns cuDeviceGet
0.00% 1.6290us 8 203ns 159ns 251ns cuDeviceGetUuid
0.00% 1.4480us 3 482ns 221ns 686ns cuDeviceGetCount
So, as you can see, there are only the two cudaMalloc calls corresponding to the two DeviceBuffer allocations.
So the extra memory must be due either to global device arrays or to context overhead.
I believe this is due to the size of the device code in the module.
(cudf_dev_10.2) mharris@dgx02:~/rapids/cudf/cpp/build/release$ cuobjdump -res-usage libcudf.so | grep Function | wc -l
9788
There are 9788 kernel functions in libcudf.so! Removing the legacy API will help, but probably not much. This is going to be an ongoing problem.
We might be able to investigate dividing libcudf into smaller .so files that are loaded on-demand by the main library.
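For illustration, loading a sub-library on demand could look something like this (the sub-library name and entry point below are made up; nothing like this exists in libcudf today):
```cpp
// Hypothetical sketch: lazily dlopen a sub-library only when its functionality
// is first requested. "libcudf_reductions.so" and "run_reduction" are made-up
// names for illustration; the real split has not been designed.
#include <dlfcn.h>
#include <stdexcept>

using reduction_fn = void (*)(const void* input, void* output);

reduction_fn load_reduction_entry_point() {
  // RTLD_LAZY defers symbol resolution; the sub-library (and its embedded
  // device code) is only mapped the first time this is called.
  static void* handle = dlopen("libcudf_reductions.so", RTLD_LAZY | RTLD_LOCAL);
  if (!handle) throw std::runtime_error(dlerror());

  static auto fn = reinterpret_cast<reduction_fn>(dlsym(handle, "run_reduction"));
  if (!fn) throw std::runtime_error(dlerror());
  return fn;
}
```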
Is it possible to figure out which kernels are taking up the most space?
Just curious, why is this a problem? 300 MB doesn't seem like a lot of memory to take up on the device.
Workloads people are pushing these days are saturating available memory (and often OOMing). Every bit counts.
Is it possible to figure out which kernels are taking up the most space?
Here are the sizes of the object files generated. I'm not sure if this is supposed to match the amount of space taken up in the final binary. Unsurprisingly, reductions code takes the longest to compile too.
Just curious, why is this a problem? 300 MB doesn't seem like a lot of memory to take up on the device.
It came up as a problem when trying to use multiple processes per GPU without using MPS. Each process then needs its own context, so it ends up being # processes * 300MB. That can easily be solved by using MPS, which should be done anyways for other reasons.
For anyone wondering why reductions are so expensive, it's because we have to instantiate N * N * K kernels, where N is the number of types we support and K is the number of reduction operators we support.
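To illustrate the combinatorics, here is a simplified sketch (not libcudf's actual dispatch code; the types and enums are made up) of how nested host-side dispatch forces the compiler to instantiate a kernel body for every (input type, output type, operator) combination:
```cpp
// Every switch happens on the host, so reduce_impl must be instantiated for
// all N input types x N output types x K operators.
#include <cstdio>

struct sum_op { template <typename T> T operator()(T a, T b) const { return a + b; } };
struct min_op { template <typename T> T operator()(T a, T b) const { return a < b ? a : b; } };

enum class type_id { INT32, FLOAT64 };
enum class op_id   { SUM, MIN };

template <typename In, typename Out, typename Op>
void reduce_impl() {                        // stand-in for launching a real kernel
  std::puts("instantiated one (In, Out, Op) combination");
}

template <typename In, typename Out>
void dispatch_op(op_id op) {                // K instantiations per (In, Out) pair
  if (op == op_id::SUM) reduce_impl<In, Out, sum_op>();
  else                  reduce_impl<In, Out, min_op>();
}

template <typename In>
void dispatch_out(type_id out, op_id op) {  // N instantiations per input type
  if (out == type_id::INT32) dispatch_op<In, int>(op);
  else                       dispatch_op<In, double>(op);
}

void reduce(type_id in, type_id out, op_id op) {  // N instantiations at the top level
  if (in == type_id::INT32) dispatch_out<int>(out, op);
  else                      dispatch_out<double>(out, op);
}

int main() { reduce(type_id::FLOAT64, type_id::INT32, op_id::SUM); }
```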
Maybe we need to switch reductions over to using Jitify as well.
What about strings (since it is next on the list)?
Within reductions, the most space is taken by MIN and MAX, presumably because they work on strings as well.
$ ls -lhS
total 76M
-rw-rw-r-- 1 dmakkar dmakkar 16M Mar 28 05:01 min.cu.o
-rw-rw-r-- 1 dmakkar dmakkar 16M Mar 28 05:01 max.cu.o
-rw-rw-r-- 1 dmakkar dmakkar 9.5M Mar 28 05:01 scan.cu.o
-rw-rw-r-- 1 dmakkar dmakkar 9.1M Mar 28 05:00 sum_of_squares.cu.o
-rw-rw-r-- 1 dmakkar dmakkar 8.8M Mar 28 05:00 sum.cu.o
-rw-rw-r-- 1 dmakkar dmakkar 8.6M Mar 28 05:00 product.cu.o
-rw-rw-r-- 1 dmakkar dmakkar 3.1M Mar 28 05:00 std.cu.o
-rw-rw-r-- 1 dmakkar dmakkar 3.1M Mar 28 05:00 var.cu.o
-rw-rw-r-- 1 dmakkar dmakkar 2.8M Mar 28 05:00 mean.cu.o
-rw-rw-r-- 1 dmakkar dmakkar 25K Mar 28 04:59 reductions.cpp.o
-rw-rw-r-- 1 dmakkar dmakkar 22K Mar 28 04:59 all.cu.o
-rw-rw-r-- 1 dmakkar dmakkar 22K Mar 28 04:59 any.cu.o
And since we don't support strings in JIT code, we have to create a separate compiled version of these just for strings, similar to how we did for binary ops.
One question though: why the discrepancy between the output type for groupby reductions and column reductions? Groupby has a fixed 1:1 input-output type mapping determined by target_type_t. Why not follow that for reductions? Especially since the output is a scalar, it shouldn't be too expensive to cast the result scalar to a different type, right? Low hanging fruit.
One question though: why the discrepancy between the output type for groupby reductions and column reductions? Groupby has a fixed 1:1 input-output type mapping determined by target_type_t. Why not follow that for reductions?
Because doing the same thing with groupby is effectively impossible. It's possible for reductions, so we do it there.
I'm fine with fixing the output type for reductions if we can still satisfy the needs of users with that approach.
And since we don't support strings in JIT code, we have to create a separate compiled version of these just for strings, similar to how we did for binary ops.
Why don't we support strings in JIT code?
It's possible for reductions, so we do it there.
I'm just wondering if the only difference is that we get a scalar of the desired type, or am I missing something? And whether casting the result would yield a different result than specifying the type at launch.
We can get more accurate size profiling by using cuobjdump. cuobjdump -xelf all libcudf.so dumps all the cubins. Unfortunately they have unhelpful names like libcudf.165.sm_70.cubin.
Here's a list of the top 10 (there are 201 of them when compiled just for sm_70!) in order of size.
-rw-r--r-- 1 mharris nvidia 13204328 Mar 29 14:48 libcudf.137.sm_70.cubin
-rw-r--r-- 1 mharris nvidia 8961000 Mar 29 14:48 libcudf.166.sm_70.cubin
-rw-r--r-- 1 mharris nvidia 8716392 Mar 29 14:48 libcudf.164.sm_70.cubin
-rw-r--r-- 1 mharris nvidia 8562920 Mar 29 14:48 libcudf.49.sm_70.cubin
-rw-r--r-- 1 mharris nvidia 8553448 Mar 29 14:48 libcudf.50.sm_70.cubin
-rw-r--r-- 1 mharris nvidia 8310376 Mar 29 14:48 libcudf.139.sm_70.cubin
-rw-r--r-- 1 mharris nvidia 8310248 Mar 29 14:48 libcudf.138.sm_70.cubin
-rw-r--r-- 1 mharris nvidia 7736680 Mar 29 14:48 libcudf.168.sm_70.cubin
-rw-r--r-- 1 mharris nvidia 6939112 Mar 29 14:48 libcudf.115.sm_70.cubin
A bit of unix command line composition confirms that the actual code size in the top cubin is close to the file size:
> cuobjdump -elf libcudf.165.sm_70.cubin | grep 'PROGBITS.*\.text' | tr -s ' ' | cut -d ' ' -f4-4 | paste -sd+ - | tr '[a-z]' '[A-Z]' | (echo -n "ibase=16; " && cat) | bc
16626304
If we look in the elf contents for the .text sections, we can see that there are only two real functions that are the culprits here: three different versions each of thrust::scan and thrust::for_each. The scans are each > 4MB (the size is the 3rd integer in each line -- 0x48f380 == 4780928).
22 18080 b9b80 0 80 PROGBITS 6 3 76000046 .text._ZN6thrust8cuda_cub4core13_kernel_agentINS0_14__parallel_for16ParallelForAgentINS0_10for_each_fINS_17counting_iteratorIiNS_11use_defaultES7_S7_EENS_6detail16wrapped_functionIN4cudf7strings6detail69_GLOBAL__N__45_tmpxft_000085a6_00000000_6_backref_re_cpp1_ii_c02a601111backrefs_fnILm10128EEEvEEEEiEESI_iEEvT0_T1_
23 d1c00 48f380 0 80 PROGBITS 100006 3 90000047 .text._ZN6thrust8cuda_cub4core13_kernel_agentINS0_6__scan9ScanAgentINS_18transform_iteratorIN4cudf7strings6detail69_GLOBAL__N__45_tmpxft_000085a6_00000000_6_backref_re_cpp1_ii_c02a601111backrefs_fnILm10128EEENS_17counting_iteratorIiNS_11use_defaultESD_SD_EESD_SD_EEPiNS_4plusIiEEiiNS_6detail17integral_constantIbLb1EEEEESF_SG_SI_iNS0_3cub13ScanTileStateIiLb1EEENS3_9DoNothingIiEEEEvT0_T1_T2_T3_T4_T5_
24 560f80 b9b80 0 80 PROGBITS 6 3 76000048 .text._ZN6thrust8cuda_cub4core13_kernel_agentINS0_14__parallel_for16ParallelForAgentINS0_10for_each_fINS_17counting_iteratorIiNS_11use_defaultES7_S7_EENS_6detail16wrapped_functionIN4cudf7strings6detail69_GLOBAL__N__45_tmpxft_000085a6_00000000_6_backref_re_cpp1_ii_c02a601111backrefs_fnILm1104EEEvEEEEiEESI_iEEvT0_T1_
25 61ab00 48f380 0 80 PROGBITS 100006 3 90000049 .text._ZN6thrust8cuda_cub4core13_kernel_agentINS0_6__scan9ScanAgentINS_18transform_iteratorIN4cudf7strings6detail69_GLOBAL__N__45_tmpxft_000085a6_00000000_6_backref_re_cpp1_ii_c02a601111backrefs_fnILm1104EEENS_17counting_iteratorIiNS_11use_defaultESD_SD_EESD_SD_EEPiNS_4plusIiEEiiNS_6detail17integral_constantIbLb1EEEEESF_SG_SI_iNS0_3cub13ScanTileStateIiLb1EEENS3_9DoNothingIiEEEEvT0_T1_T2_T3_T4_T5_
26 aa9e80 b9b80 0 80 PROGBITS 6 3 7600004a .text._ZN6thrust8cuda_cub4core13_kernel_agentINS0_14__parallel_for16ParallelForAgentINS0_10for_each_fINS_17counting_iteratorIiNS_11use_defaultES7_S7_EENS_6detail16wrapped_functionIN4cudf7strings6detail69_GLOBAL__N__45_tmpxft_000085a6_00000000_6_backref_re_cpp1_ii_c02a601111backrefs_fnILm112EEEvEEEEiEESI_iEEvT0_T1_
27 b63a00 300 0 80 PROGBITS 6 3 c00004b .text._ZN6thrust8cuda_cub4core13_kernel_agentINS0_14__parallel_for16ParallelForAgentINS0_11__transform17unary_transform_fINS_10device_ptrIKiEEPiNS5_14no_stencil_tagENS_8identityIiEENS5_21always_true_predicateEEElEESF_lEEvT0_T1_
28 b63d00 48f380 0 80 PROGBITS 100006 3 9000004c .text._ZN6thrust8cuda_cub4core13_kernel_agentINS0_6__scan9ScanAgentINS_18transform_iteratorIN4cudf7strings6detail69_GLOBAL__N__45_tmpxft_000085a6_00000000_6_backref_re_cpp1_ii_c02a601111backrefs_fnILm112EEENS_17counting_iteratorIiNS_11use_defaultESD_SD_EESD_SD_EEPiNS_4plusIiEEiiNS_6detail17integral_constantIbLb1EEEEESF_SG_SI_iNS0_3cub13ScanTileStateIiLb1EEENS3_9DoNothingIiEEEEvT0_T1_T2_T3_T4_T5_
29 ff3080 180 0 80 PROGBITS 6 3 a00004d .text._ZN6thrust8cuda_cub4core13_kernel_agentINS0_6__scan9InitAgentINS0_3cub13ScanTileStateIiLb1EEEiEES7_iEEvT0_T1_
2a ff3200 80 0 80 PROGBITS 6 3 400004e .text._ZN6thrust8cuda_cub4core13_kernel_agentINS0_14__parallel_for16ParallelForAgentINS0_10for_each_fINS_10device_ptrINS_4pairIiiEEEENS_6detail16wrapped_functionINSA_23allocator_traits_detail5gozerEvEEEElEESF_lEEvT0_T1_
2b ff3280 80 0 80 PROGBITS 6 3 400004f .text._ZN6thrust8cuda_cub3cub11EmptyKernelIvEEvv
Demangling the first function name above we get:
void thrust::cuda_cub::core::_kernel_agent<thrust::cuda_cub::__parallel_for::ParallelForAgent<thrust::cuda_cub::for_each_f<thrust::counting_iterator<int, thrust::use_default, thrust::use_default, thrust::use_default>, thrust::detail::wrapped_function<cudf::strings::detail::(anonymous namespace)::backrefs_fn<10128ul>, void> >, int>, thrust::cuda_cub::for_each_f<thrust::counting_iterator<int, thrust::use_default, thrust::use_default, thrust::use_default>, thrust::detail::wrapped_function<cudf::strings::detail::(anonymous namespace)::backrefs_fn<10128ul>, void> >, int>(thrust::cuda_cub::for_each_f<thrust::counting_iterator<int, thrust::use_default, thrust::use_default, thrust::use_default>, thrust::detail::wrapped_function<cudf::strings::detail::(anonymous namespace)::backrefs_fn<10128ul>, void> >, int)
Searching the code base I find backrefs_fn in two files: one legacy (cudf/cpp/custrings/strings/replace_backref.cu) and one libcudf++ file (cudf/cpp/strings/replace/backref_re.cu). Only the former has both thrust::scan and thrust::for_each, so this .cubin must be due to legacy.
That said, these are very large kernels. Doing similar analysis on the rest of the .cubins will be a bit tedious.
Is it really a significant benefit to have per-type kernels for things like min/max reductions? (It seems like it would only save a non-divergent branch within the kernel.)
If you are suggesting moving type dispatch inside the reduction, I agree that's certainly an option to explore. @karthikeyann explored that for binops (can't find the PR right now, on mobile).
It's possible for reductions, so we do it there.
I'm just wondering if the only difference is that we get a scalar of the desired type, or am I missing something? And whether casting the result would yield a different result than specifying the type at launch.
I answered my own question. The type being specified at launch is going to follow the behaviour of thrust and stl's reduce functionalities:
T thrust::reduce(InputIterator first, InputIterator last, T init, BinaryFunction binary_op);
T std::accumulate( InputIt first, InputIt last, T init, BinaryOperation op );
where the input iterator is first converted to type T and then passed to the binary operation functor.
Meaning, calling cudf::reduce(sum, int) on [1.5, 2.0, 3.6] should result in 1 + 2 + 3 = 6 rather than (int)(1.5 + 2.0 + 3.6) = (int)7.1 = 7.
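For illustration, the same behaviour with plain std::accumulate on the host (not cudf code): the accumulator type comes from the init value, so partial sums are truncated to int at each step.
```cpp
// std::accumulate deduces the accumulator type T from the init value, so with
// an int init the partial sums are truncated to int after every element.
#include <numeric>
#include <vector>
#include <iostream>

int main() {
  std::vector<double> v{1.5, 2.0, 3.6};
  int    as_int    = std::accumulate(v.begin(), v.end(), 0);    // accumulates in int    -> 6
  double as_double = std::accumulate(v.begin(), v.end(), 0.0);  // accumulates in double -> 7.1
  std::cout << as_int << " " << as_double << "\n";              // prints "6 7.1"
  return 0;
}
```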
So that's a dead end. Pity because when I tried it, fixing the output type gave a big improvement in compile time (1m44s -> 27s) and binary size (68MB -> 12.2 MB)
If you are suggesting moving type dispatch inside the reduction, I agree that's certainly an option to explore. @karthikeyann explored that for binops (can't find the PR right now, on mobile).
compiled binary ops experiment PR https://github.com/rapidsai/cudf/pull/4269
With N being the number of types we support and K the number of reduction operators we support, we have N*N*K reduction kernels because of host-side type dispatch. The extra N term comes from the output type conversion, which can be removed by doing a device-side dispatch. Since it's just a static_cast, there is no significant code change between these kernels (literally it could be just a 1-PTX-instruction difference among these N kernels).
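Here's a conceptual host-side sketch of that idea (not libcudf's actual kernels; the names and types are made up). With host-side dispatch, the N output-type instantiations differ only in the final cast; resolving the output type at run time keeps one instantiation per input type and operator:
```cpp
// Conceptual sketch only. reduce_to<Out, In> differs between output types only
// in its final static_cast, yet host-side dispatch stamps out N copies of it.
// reduce_any_out resolves the output type at run time instead, so the cast
// costs one instruction rather than one extra kernel instantiation.
#include <cstdint>

enum class type_id { INT32, FLOAT64 };

template <typename Out, typename In>
void reduce_to(In const* in, int n, Out* out) {       // host-side dispatch style
  In acc{};
  for (int i = 0; i < n; ++i) acc = acc + in[i];
  *out = static_cast<Out>(acc);                        // the only line that differs per Out
}

template <typename In>
void reduce_any_out(In const* in, int n, void* out, type_id out_type) {  // "device-side" dispatch style
  In acc{};
  for (int i = 0; i < n; ++i) acc = acc + in[i];
  switch (out_type) {                                  // runtime dispatch of the cast
    case type_id::INT32:   *static_cast<std::int32_t*>(out) = static_cast<std::int32_t>(acc); break;
    case type_id::FLOAT64: *static_cast<double*>(out)       = static_cast<double>(acc);       break;
  }
}

int main() {
  double in[] = {1.5, 2.0, 3.6};
  std::int32_t out_i = 0;
  reduce_to(in, 3, &out_i);
  reduce_any_out(in, 3, &out_i, type_id::INT32);
  return 0;
}
```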
That said, these are very large kernels. Doing similar analysis on the rest of the .cubins will be a bit tedious.
The whole libcudf.so can be analysed and demangled like this:
cuobjdump -elf libcudf.so | grep 'PROGBITS.*\.text' | sed "s/.text.//" | c++filt > demangled_list
cat demangled_list | tr -s ' ' | perl -e 'print sort {hex((split(/\s+/,$a))[3]) <=> hex((split(/\s+/,$b))[3])} <>;' \
| sed 's/, thrust::null_type//g' | sed 's/, thrust::use_default//g' > sorted_list
grep -Po "void .*?<.*?<" sorted_list | sort | uniq -c | sort -nr | head
1989 void parallel_for::ParallelForAgent<thrust::cuda_cub::__transform::unary_transform_f<
1676 void cub::DeviceReduceSingleTileKernel<cub::DeviceReducePolicy<
1548 void cub::DeviceReduceKernel<cub::DeviceReducePolicy<
509 void parallel_for::ParallelForAgent<thrust::cuda_cub::for_each_f<
250 void cudf::experimental::detail::(anonymous namespace)::gpu_rolling<cudf::detail::timestamp<
244 void scan::ScanAgent<thrust::transform_iterator<
180 void cudf::experimental::detail::(anonymous namespace)::copy_if_else_kernel<256, cudf::detail::timestamp<
122 void parallel_for::ParallelForAgent<thrust::cuda_cub::__uninitialized_fill::functor<
115 void cudf::unary::gpu_op_kernel<cudf::detail::wrapper<
105 void cudf::experimental::detail::valid_if_n_kernel<thrust::counting_iterator<
Old code:
libcudf.so PROGBITS sum = 212421344
PROGBITS for cub::DeviceReduceSingleTileKernel and cub::DeviceReduceKernel = 100*(34956928+33131392)/212421344 = 32.05%
libcudf.so size = 352MB
New code (reduction device dispatch):
libcudf.so PROGBITS sum = 199375072
PROGBITS for cub::DeviceReduceSingleTileKernel and cub::DeviceReduceKernel = 100*(28852608+26571904)/199375072 = 27.79%
libcudf.so size = 318MB (34MB less)
But unfortunately, it did not have a significant effect on idle device memory usage (this bug!). For the Example 3) code (running on a 12-core machine with a GV100 sm_70 GPU):
Old code: rmm (440 MB) + libcudf = 708 MB
New code (reduction device dispatch): rmm (440 MB) + libcudf = 692 MB
So, 16 MB less = 5.97% less.
Here are the top kernel types by count for the new code:
grep -Po "void .*?<.*?<" sorted_list | sort | uniq -c | sort -nr | head
1989 void parallel_for::ParallelForAgent<thrust::cuda_cub::__transform::unary_transform_f<
1044 void cub::DeviceReduceSingleTileKernel<cub::DeviceReducePolicy<
916 void cub::DeviceReduceKernel<cub::DeviceReducePolicy<
437 void parallel_for::ParallelForAgent<thrust::cuda_cub::for_each_f<
250 void cudf::experimental::detail::(anonymous namespace)::gpu_rolling<cudf::detail::timestamp<
244 void scan::ScanAgent<thrust::transform_iterator<
180 void cudf::experimental::detail::(anonymous namespace)::copy_if_else_kernel<256, cudf::detail::timestamp<
122 void parallel_for::ParallelForAgent<thrust::cuda_cub::__uninitialized_fill::functor<
115 void cudf::unary::gpu_op_kernel<cudf::detail::wrapper<
105 void cudf::experimental::detail::valid_if_n_kernel<thrust::counting_iterator<
Just out of curiosity, what's the size difference in debug build?
I don't think debug builds are succeeding currently, right? Waiting to remove legacy to see if that fixes the relocation problem...
Just out of curiosity, what's the size difference in debug build?
The libcudf.so debug build size is 2.0 GB.
Example 3) takes 3026 MiB.
In working with rmm and cudf, if one only uses rmm we observe a lower device memory usage than if we load the libcudf shared library and only use rmm.
The CUDA runtime loads the cubin files (modules) to the GPU on the first invocation of a CUDA runtime API. So even if no libcudf APIs are used, the cubin files in libcudf.so are copied to the GPU (per process) as soon as any CUDA call is invoked.
Loading/unloading cubins to GPU memory can be controlled with the CUDA driver APIs, but that seems like an excessive approach.
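For reference, explicit module control with the driver API would look roughly like this (the .cubin path and kernel name are placeholders; libcudf does not do this today), which also shows how invasive it would be:
```cpp
// Sketch of explicit module management with the CUDA driver API.
// "kernels.cubin" and "my_kernel" are placeholder names.
#include <cuda.h>
#include <cstdio>

int main() {
  cuInit(0);

  CUdevice dev;
  cuDeviceGet(&dev, 0);

  CUcontext ctx;
  cuCtxCreate(&ctx, 0, dev);

  // Device code is only resident on the GPU while the module is loaded.
  CUmodule mod;
  if (cuModuleLoad(&mod, "kernels.cubin") != CUDA_SUCCESS) {
    std::fprintf(stderr, "failed to load module\n");
    return 1;
  }

  CUfunction fn;
  cuModuleGetFunction(&fn, mod, "my_kernel");
  // ... cuLaunchKernel(fn, ...) as needed ...

  cuModuleUnload(mod);   // frees the device memory backing the module's code
  cuCtxDestroy(ctx);
  return 0;
}
```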
Similar to NPP, we could try the approach of splitting the library into sub-libraries (multiple .so files) and loading only the necessary sub-libraries, as @harrism suggested.
I think we are premature optimizing here. Once we remove legacy APIs and NVStrings/NVCategory, I'd expect the library size should go down significantly.
Yeah, that's the plan and hope. This is all just discussion until then.
Old code: rmm (440 MB) + libcudf = 708 MB
After legacy code removal, Example (3) takes rmm (440 MB) + libcudf = 618 MB.
I got different numbers. When I run Keith's last example above, before loading libcudf, nvidia-smi reports:
+-----------------------------------------------------------------------------+
| Processes: GPU Memory |
| GPU PID Type Process name Usage |
|=============================================================================|
| 3 9656 C .../conda/cuda_10.2/envs/rapids/bin/python 279MiB |
+-----------------------------------------------------------------------------+
After importing libcudf and creating a 5-byte DeviceBuffer, it reports:
+-----------------------------------------------------------------------------+
| Processes: GPU Memory |
| GPU PID Type Process name Usage |
|=============================================================================|
| 3 9656 C .../conda/cuda_10.2/envs/rapids/bin/python 417MiB |
+-----------------------------------------------------------------------------+
So that means importing libcudf takes 138 MiB, down from 316 MiB in Keith's test!
@kkraus14 how low do we need to go in order to close this bug? Is there a nice way to write a test or benchmark that we can use to track this?
@kkraus14 how low do we need to go in order to close this bug? Is there a nice way to write a test or benchmark that we can use to track this?
I don't think this is even a bug anymore, was just odd behavior I noticed that I lacked understanding on. I'm perfectly happy closing this as is.
In theory we could build a unit test to handle this, but having something reliable in CI would likely be tricky.
Will close for now, since we'll be tracking this along with compile time reductions.
In working with rmm and cudf, if one only uses rmm we observe a lower device memory usage than if we load the libcudf shared library and only use rmm. Apologies for the Python only examples:
Example 1, not loading libcudf.so, device memory usage ~335 MiB:
Example 2, loading libcudf.so, device memory usage ~651 MiB:
Example 3, construct and destruct buffer then load libcudf.so then construct and destruct buffer again, ~335 MiB in first construct/destruct, ~651 MiB after second construct/destruct: