rapidsai / raft

RAFT contains fundamental widely-used algorithms and primitives for machine learning and information retrieval. The algorithms are CUDA-accelerated and form building blocks for more easily writing high performance applications.
https://docs.rapids.ai/api/raft/stable/
Apache License 2.0

[ENH] mdspan-ify iterator-based API for `linalg::reduce_rows_by_keys` #925

Open Nyrio opened 2 years ago

Nyrio commented 2 years ago

Following #909, the API for linalg::reduce_rows_by_keys can take custom iterators, in order to save intermediate steps in ann_kmeans_balanced. But raw-pointer/iterator APIs are being deprecated in favor of mdspan.

We should provide appropriate helpers and types for iterator-based mdspan, and change this API accordingly. See discussions on the aforementioned PR.

Nyrio commented 2 years ago

cc @cjnolet @mhoemmen who were part of the discussion about this.

mhoemmen commented 2 years ago

Hi @Nyrio and @cjnolet! I can think of two likely meanings for "iterator-based mdspan."

  1. "An mdspan whose data_handle_type is a random access iterator"
  2. "Implement begin and end for generic mdspan"

Given the context, I'm guessing that you mean (1). You'll need at least forward iterators (because mdspan depends on the multipass guarantee). You can imitate default_accessor to write a generic iterator accessor. Use std::iterator_traits to get the iterator's value and reference types, and std::advance (if you want to support forward iterators that are not also random access iterators) in the access and offset member functions. For example:

using reference = typename std::iterator_traits<Iterator>::reference;
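reference access(Iterator iter, std::size_t index) const {
  // std::advance modifies iter in place and returns void,
  // so advance the local copy first, then dereference it.
  std::advance(iter, index);
  return *iter;
}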
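
To make the access and offset member functions described above more concrete, here is a minimal sketch of both. The name iterator_access_sketch is made up for illustration, and the struct deliberately leaves out the other aliases an accessor policy needs (element_type, offset_policy). It uses std::next, which copies the iterator, advances the copy, and returns it, so it works for forward iterators too.

```cpp
#include <cstddef>
#include <iterator>

// Hypothetical name; only the two member functions discussed above are shown.
template <class Iterator>
struct iterator_access_sketch {
  using data_handle_type = Iterator;
  using reference        = typename std::iterator_traits<Iterator>::reference;
  using difference_type  = typename std::iterator_traits<Iterator>::difference_type;

  // Element at position `index`, counted from `iter`.
  constexpr reference access(data_handle_type iter, std::size_t index) const {
    return *std::next(iter, static_cast<difference_type>(index));
  }

  // Handle that refers to position `index` (what slicing would build on).
  constexpr data_handle_type offset(data_handle_type iter, std::size_t index) const {
    return std::next(iter, static_cast<difference_type>(index));
  }
};
```

For a random access iterator std::next is constant time. For a plain forward iterator each access costs a linear walk, which is one more reason to require random access in practice.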

For defining the accessor's required type aliases (see [mdspan.accessor.reqmts]), it's a bit tricky to get the element_type (it's const value_type if the iterator is an iterator-of-const, else value_type). If you can figure out how to do that, say via type alias template<std::forward_iterator Iterator> using iterator_element_t = /* ... */, then you'll find that useful below.

The question then becomes how to apply __host__ or __device__ (or both, for managed or pinned allocations) to access, so that you get the type safety benefits of RAFT's host_accessor and device_accessor. (For this case, you might also need to require random access iterators and replace std::advance with iter[index], unless libcu++ has a std::advance equivalent that is marked __host__ __device__.) If you have different custom iterator types for host vs. device access, you can map from the iterator type, via a trait, to whether the access function has __host__, __device__, or both (e.g., via specialization of a base class).
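
Here is a rough sketch of the "specialization of a base class" idea. Everything in it is hypothetical and not RAFT's actual API: the memory_space enum, the iterator_memory_space trait, the SKETCH_* macros, and the class names. The macros expand to the CUDA annotations only under nvcc, so the same code also compiles as plain host C++. It requires random access iterators and uses iter[index].

```cpp
#include <cstddef>
#include <iterator>

// Expand to CUDA annotations under nvcc, and to nothing in a host-only build.
#if defined(__CUDACC__)
#define SKETCH_HOST __host__
#define SKETCH_DEVICE __device__
#else
#define SKETCH_HOST
#define SKETCH_DEVICE
#endif

enum class memory_space { host, device, managed };

// Hypothetical trait: custom iterator types specialize this. Default is host.
template <class Iterator>
struct iterator_memory_space {
  static constexpr memory_space value = memory_space::host;
};

// Primary template left undefined; one specialization per memory space.
template <class Iterator, memory_space Space>
struct iterator_access_base;

template <class Iterator>
struct iterator_access_base<Iterator, memory_space::host> {
  using reference = typename std::iterator_traits<Iterator>::reference;
  using diff_t    = typename std::iterator_traits<Iterator>::difference_type;
  SKETCH_HOST constexpr reference access(Iterator iter, std::size_t i) const {
    return iter[static_cast<diff_t>(i)];
  }
};

template <class Iterator>
struct iterator_access_base<Iterator, memory_space::device> {
  using reference = typename std::iterator_traits<Iterator>::reference;
  using diff_t    = typename std::iterator_traits<Iterator>::difference_type;
  SKETCH_DEVICE constexpr reference access(Iterator iter, std::size_t i) const {
    return iter[static_cast<diff_t>(i)];
  }
};

template <class Iterator>
struct iterator_access_base<Iterator, memory_space::managed> {
  using reference = typename std::iterator_traits<Iterator>::reference;
  using diff_t    = typename std::iterator_traits<Iterator>::difference_type;
  SKETCH_HOST SKETCH_DEVICE constexpr reference access(Iterator iter, std::size_t i) const {
    return iter[static_cast<diff_t>(i)];
  }
};

// The accessor picks its base from the trait, so the annotation on `access`
// follows from the iterator type alone.
template <class Iterator>
struct iterator_accessor_sketch
  : iterator_access_base<Iterator, iterator_memory_space<Iterator>::value> {};
```

Specializing iterator_memory_space for a device-side iterator type would select the __device__-only access, so calling it from host code fails at compile time under nvcc. A real accessor would also need offset and the remaining type aliases.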

At this point, you can convert your iterator ranges concisely into rank-1 mdspan.

template<std::forward_iterator Iterator, std::integral IndexType, std::size_t Extent = dynamic_extent>
using iterator_mdspan_t = mdspan<
  iterator_element_t<Iterator>, // (possibly const value_type)
  extents<IndexType, Extent>,
  iterator_accessor<Iterator>>;

// For pre-C++20, remove Sentinel template parameter,
// make the second function parameter an Iterator, and
// replace std::ranges::distance with std::distance.
template<class Iterator, class Sentinel>
auto range_to_mdspan(Iterator begin, Sentinel end) {
  auto distance = std::ranges::distance(begin, end);
  using distance_type = decltype(distance);
  return iterator_mdspan_t<Iterator, distance_type, dynamic_extent>{begin, distance};
}