rapidsai / cudf

cuDF - GPU DataFrame Library
https://docs.rapids.ai/api/cudf/stable/
Apache License 2.0

Use "ranger" to prevent grid stride loop overflow #10368

Open nvdbaranec opened 2 years ago

nvdbaranec commented 2 years ago

(updated Aug 2023)

Background

We found a kernel indexing overflow issue, first discovered in the fused_concatenate kernels (https://github.com/rapidsai/cudf/issues/10333). The same issue is present in a number of our CUDA kernels that take the following form:

size_type output_index = threadIdx.x + blockIdx.x * blockDim.x;
while (output_index < output_size) {
  // ... process element output_index ...
  output_index += blockDim.x * gridDim.x;
}

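Suppose output_size is 1.2 billion and the grid has about as many threads (blockDim.x * gridDim.x ≈ 1.2 billion). A late thread, say with output_index around 1.19 billion, adds the 1.2 billion stride and overflows size_type, a signed 32-bit integer whose maximum is 2,147,483,647. The wrapped value is negative, so it still compares less than output_size and the loop keeps running with an out-of-bounds index.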
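A host-side C++ sketch of that last step (the helper names are hypothetical; it assumes size_type is std::int32_t, as in libcudf, and a two's-complement target, as on CUDA devices). Because blockDim.x * gridDim.x is an unsigned int, output_index is converted to unsigned, the sum is formed in 32-bit unsigned arithmetic, and the result is stored back into the signed index:

```cpp
#include <cassert>
#include <cstdint>

// Assumption: size_type is a signed 32-bit integer, as in libcudf.
using size_type = std::int32_t;

// Mimics `output_index += blockDim.x * gridDim.x;` where the stride is an
// unsigned int: the sum is computed in unsigned 32-bit arithmetic and then
// converted back to the signed index (modular on two's-complement targets,
// and guaranteed modular since C++20).
size_type next_index_32(size_type output_index, std::uint32_t stride)
{
  std::uint32_t const sum = static_cast<std::uint32_t>(output_index) + stride;
  return static_cast<size_type>(sum);
}

// The same step with a 64-bit index cannot wrap for any size_type-sized column.
std::int64_t next_index_64(std::int64_t output_index, std::uint32_t stride)
{
  return output_index + static_cast<std::int64_t>(stride);
}

// True if the 32-bit loop would run another iteration after this step.
bool loop_continues_32(size_type output_index, std::uint32_t stride, size_type output_size)
{
  assert(output_index >= 0 && output_index < output_size);
  return next_index_32(output_index, stride) < output_size;
}
```

For output_index = 1,190,000,000 and a stride of 1,200,000,000, next_index_32 yields -1,904,967,296, so loop_continues_32 returns true even though the index is far outside the column. The 64-bit version yields 2,390,000,000 and the loop ends.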

We made a round of fixes in #10448, and then later found another instance of this error in #13838. Our first pass of investigation was not adequate to contain the issue, so we need to take another close look.

Part 1 - First pass: fix kernels with this issue

| Source file | Kernels | Status |
|---|---|---|
| copying/concatenate.cu | fused_concatenate_kernel | #10448 |
| valid_if.cuh | valid_if_kernel | #10448 |
| scatter.cu | marking_bitmask_kernel | #10448 |
| replace/nulls.cu | replace_nulls_strings | #10448 |
| replace/nulls.cu | replace_nulls | #10448 |
| rolling/rolling_detail.cuh | gpu_rolling | #10448 |
| rolling/jit/kernel.cu | gpu_rolling_new | #10448 |
| transform/compute_column.cu | compute_column_kernel | #10448 |
| copying/concatenate.cu | fused_concatenate_string_offset_kernel | #13838 |
| replace/replace.cu | replace_strings_first_pass, replace_strings_second_pass, replace_kernel | #13905 |
| copying/concatenate.cu | concatenate_masks_kernel, fused_concatenate_string_offset_kernel, fused_concatenate_string_chars_kernel, fused_concatenate_kernel (int64) | #13906 |
| hash/helper_functions.cuh | init_hashtbl | #13895 |
| null_mask.cu | set_null_mask_kernel, copy_offset_bitmask, count_set_bits_kernel | #13895 |
| transform/row_bit_count.cu | compute_row_sizes | #13895 |
| multibyte_split.cu | multibyte_split_init_kernel, multibyte_split_seed_kernel (auto??), multibyte_split_kernel | #13910 |
| IO modules: parquet, orc, json | | #13910 |
| io/utilities/parsing_utils.cu | count_and_set_positions (uint64_t) | #13910 |
| conditional_join_kernels.cuh | compute_conditional_join_output_size, conditional_join | #13971 |
| merge.cu | materialize_merged_bitmask_kernel | #13972 |
| partitioning.cu | compute_row_partition_numbers, compute_row_output_locations, copy_block_partitions | #13973 |
| json_path.cu | get_json_object_kernel | #13962 |
| tdigest | compute_percentiles_kernel (int) | #13962 |
| strings/attributes.cu | count_characters_parallel_fn | #13968 |
| strings/convert/convert_urls.cu | url_decode_char_counter (int), url_decode_char_replacer (int) | #13968 |
| text/subword/data_normalizer.cu | kernel_data_normalizer (uint32_t) | #13915 |
| text/subword/subword_tokenize.cu | kernel_compute_tensor_metadata (uint32_t) | #13915 |
| text/subword/wordpiece_tokenizer.cu | init_data_and_mark_word_start_and_ends (uint32_t), mark_string_start_and_ends (uint32_t), kernel_wordpiece_tokenizer (uint32_t) | #13915 |

Part 2 - Take another pass over more challenging kernels

| Source file | Kernels | Status |
|---|---|---|
| null_mask.cuh | subtract_set_bits_range_boundaries_kernel | |
| valid_if.cuh | valid_if_n_kernel | |
| copy_if_else.cuh | copy_if_else_kernel | |
| gather.cuh | gather_chars_fn_string_parallel | |
| more? | search for gridDim.x or blockDim.x to find more examples | |

Part 3 - Use ranger to prevent grid stride loop overflow

Additional information

There are also a number of kernels that have this pattern but probably never overflow because they index by bitmask words. (Example) Additionally, in this kernel, source_idx probably overflows, but harmlessly.
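A rough sketch of the headroom argument for word-indexed loops, assuming 32-bit bitmask words and a grid whose total thread count does not exceed the number of words (a hypothetical launch policy; individual kernels may size their grids differently). The helper names are illustrative:

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

using size_type = std::int32_t;  // assumption: libcudf's 32-bit signed size_type

// Number of 32-bit bitmask words needed to cover num_rows bits (rounded up).
std::int64_t num_bitmask_words(size_type num_rows)
{
  assert(num_rows >= 0);
  return (static_cast<std::int64_t>(num_rows) + 31) / 32;
}

// Worst case for a grid-stride loop over words: the last in-range word index
// plus a stride no larger than the word count (the assumed launch policy).
std::int64_t worst_case_word_index(size_type num_rows)
{
  std::int64_t const words = num_bitmask_words(num_rows);
  return (words - 1) + words;
}
```

Even for the largest possible column (2,147,483,647 rows), this gives 67,108,864 words and a worst-case word index of 134,217,727, far below the size_type maximum.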

A snippet of code to see this in action:

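#include <cudf/aggregation.hpp>
#include <cudf/column/column_factories.hpp>
#include <cudf/rolling.hpp>
cudf::size_type const size = 1200000000;
auto big = cudf::make_fixed_width_column(cudf::data_type{cudf::type_id::INT32}, size, cudf::mask_state::UNALLOCATED);
auto x = cudf::rolling_window(*big, 1, 1, 1, *cudf::make_sum_aggregation<cudf::rolling_aggregation>());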

Note: RMM may mask out-of-bounds accesses in some cases (for example, when a pool allocator hands out memory carved from a larger allocation), so it's helpful to run with the plain CUDA allocator.

nvdbaranec commented 2 years ago

The fix in basically all of these cases is quite simple: just make the index a size_t
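A host-side sketch of the size_t fix, with the CUDA built-ins modeled as plain uint32_t fields of a hypothetical launch_coords struct. One detail matters beyond the declared type of the index: blockIdx.x * blockDim.x and blockDim.x * gridDim.x are unsigned int products, so one operand should be widened before multiplying; otherwise the product itself can wrap at 2^32 before it is assigned to the size_t. Assumes a 64-bit size_t.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Stand-ins for threadIdx.x, blockIdx.x, blockDim.x and gridDim.x,
// which are all unsigned int in CUDA device code.
struct launch_coords {
  std::uint32_t thread_idx;
  std::uint32_t block_idx;
  std::uint32_t block_dim;
  std::uint32_t grid_dim;
};

// Starting index: widen before multiplying so the product is 64-bit.
std::size_t start_index(launch_coords const& c)
{
  return static_cast<std::size_t>(c.block_idx) * c.block_dim + c.thread_idx;
}

// Grid stride, also widened before multiplying.
std::size_t grid_stride(launch_coords const& c)
{
  return static_cast<std::size_t>(c.block_dim) * c.grid_dim;
}

// Counts how many elements one thread visits in a size_t grid-stride loop.
std::size_t elements_visited(launch_coords const& c, std::size_t output_size)
{
  std::size_t count = 0;
  for (std::size_t i = start_index(c); i < output_size; i += grid_stride(c)) {
    ++count;  // a real kernel would process element i here
  }
  return count;
}
```

With blockDim.x = 1024 and gridDim.x = 1,171,875 (a stride of exactly 1.2 billion), the thread starting at index 1.19 billion visits one element and then exits instead of wrapping.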

jrhemstad commented 2 years ago

I'd love to just add an algorithm to do this.

https://godbolt.org/z/hK95z7zff
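A hypothetical sketch of what such an algorithm could look like (grid_stride_for_each is an illustrative name, not the code behind the link and not a libcudf API): the caller passes a per-element functor, and the grid-stride index math lives in exactly one place. To keep it runnable on the host, the launch is emulated by looping over every (block, thread) pair; in CUDA, those two outer loops would be the grid itself.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical algorithm: calls f(i) once for every i in [0, n), using a
// grid-stride loop with 64-bit indices. The outer loops emulate the launch.
template <typename F>
void grid_stride_for_each(std::size_t n, std::uint32_t block_dim, std::uint32_t grid_dim, F&& f)
{
  assert(block_dim > 0 && grid_dim > 0);
  std::size_t const stride = static_cast<std::size_t>(block_dim) * grid_dim;
  for (std::uint32_t block = 0; block < grid_dim; ++block) {
    for (std::uint32_t thread = 0; thread < block_dim; ++thread) {
      // Body of the would-be kernel: widened start index and stride.
      for (std::size_t i = static_cast<std::size_t>(block) * block_dim + thread; i < n;
           i += stride) {
        f(i);
      }
    }
  }
}
```

Every index in [0, n) is visited exactly once regardless of how the grid is shaped, and the functor never sees the index arithmetic.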

harrism commented 2 years ago

Or even a simple range helper for for loops: https://github.com/harrism/hemi#simple-grid-stride-loops

harrism commented 2 years ago

> The fix in basically all of these cases is quite simple: just make the index a size_t

I think the general approach should be:

  1. Use an algorithm (thrust:: or std::) if possible before ever writing a custom kernel -- this way you write a per-element device functor instead, and indexing is handled for you.
  2. If a custom kernel must be written, it should use device-side algorithms instead of raw loops.
  3. If a raw grid-stride loop is required and an existing algorithm won't work, we should provide utilities to abstract the iteration and/or the range and use auto for the type to avoid these mistakes.

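A minimal sketch of the third option: a range helper consumed by a range-based for loop with auto. grid_stride_range is hypothetical and is not the API of hemi or ranger; in a kernel, first and stride would be derived from the CUDA built-ins, while here they are ordinary parameters so the code runs on the host.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical grid-stride range over [first, n) with step `stride`.
class grid_stride_range {
 public:
  class iterator {
   public:
    iterator(std::size_t i, std::size_t stride, std::size_t end) : i_{i}, stride_{stride}, end_{end} {}
    std::size_t operator*() const { return i_; }
    iterator& operator++()
    {
      // i_ + stride_ is only computed when it stays below end_; otherwise
      // jump straight to end_. The index therefore never wraps.
      i_ = (end_ - i_ > stride_) ? i_ + stride_ : end_;
      return *this;
    }
    bool operator!=(iterator const& other) const { return i_ != other.i_; }

   private:
    std::size_t i_, stride_, end_;
  };

  grid_stride_range(std::size_t first, std::size_t stride, std::size_t n)
    : first_{first < n ? first : n}, stride_{stride}, end_{n}
  {
    assert(stride > 0);
  }
  iterator begin() const { return {first_, stride_, end_}; }
  iterator end() const { return {end_, stride_, end_}; }

 private:
  std::size_t first_, stride_, end_;
};
```

In a kernel this would read for (auto i : grid_stride_range(first, stride, n)) { ... }; auto deduces std::size_t, so the loop variable cannot silently be a 32-bit size_type.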
github-actions[bot] commented 2 years ago

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

nvdbaranec commented 2 years ago

Still relevant.

harrism commented 2 years ago

Created https://github.com/harrism/ranger as a solution to this. Needs to be moved into libcudf.



github-actions[bot] commented 1 year ago

This issue has been labeled inactive-90d due to no recent activity in the past 90 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.

GregoryKimball commented 1 year ago

Thanks @harrism for creating the ranger repo! Do you think we are ready to kick off an integration with libcudf, or does ranger need more development first?

harrism commented 1 year ago

@GregoryKimball I have now created a PR to use ranger in libcuspatial. You guys could use this as an example if you want to do the same in libcudf. https://github.com/rapidsai/cuspatial/pull/1178

wence- commented 1 year ago

Regarding finding places where this might be happening: in host code, clang and gcc will warn under some circumstances if you add -Wsign-conversion (which is not enabled by -Wall -Wextra). Unfortunately, there is no such option for nvcc.

#include <cstdint>
int what(int upper)
{
  int i = 0; // no warning if this is a std::int64_t
  unsigned int stride = 10;
  while (i < upper) {
    i = i + stride; // clang warns for this, so does gcc
  }
  i = 0;
  while (i < upper) {
    i += stride; // gcc warns for this, clang does not.
  }
  return i;
}
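As the comment in the snippet notes, making the accumulator a std::int64_t avoids the warning. A sketch of that variant (what_wide is an illustrative name): unsigned int converts to std::int64_t without changing sign, and since i is at most INT_MAX before each addition, i + stride always fits in 64 bits, so the loop also cannot overflow.

```cpp
#include <cassert>
#include <cstdint>

// Variant of `what` with a 64-bit accumulator: the unsigned stride converts
// to std::int64_t without a sign change, and the sum cannot overflow.
std::int64_t what_wide(int upper)
{
  std::int64_t i = 0;
  unsigned int stride = 10;
  assert(stride > 0);
  while (i < upper) {
    i += stride;
  }
  return i;
}
```

what_wide(2147483647) returns 2,147,483,650, a value the int version could never hold.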