Open Nyrio opened 2 years ago
cc @cjnolet @mhoemmen who were part of the discussion about this.
Hi @Nyrio and @cjnolet! I can think of two likely meanings for "iterator-based mdspan":

1. "`data_handle_type` is a random access iterator"
2. "`begin` and `end` for generic mdspan"

Given the context, I'm guessing that you mean (1). You'll need at least forward iterators (because mdspan depends on the multipass guarantee). You can imitate `default_accessor` to write a generic iterator accessor. Use `std::iterator_traits` to get the iterator's value type, and `std::advance` (if you want to support forward iterators that are not also random access iterators) in the `access` and `offset` member functions. For example:
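    using reference = typename std::iterator_traits<Iterator>::reference;
    reference access(Iterator iter, std::size_t index) const {
        // std::advance returns void and moves iter in place,
        // so advance the local copy first, then dereference it.
        std::advance(iter, index);
        return *iter;
    }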
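To make the snippet above concrete, here is a minimal sketch of what such an accessor could look like with both `access` and `offset`. It is only an illustration under a few assumptions: the struct name `iterator_accessor` is used without any claim that RAFT defines it this way, the `element_type` alias is left out because it needs separate treatment, and `iterator_accessor_demo` is a hypothetical helper that exists only to exercise the code.

```cpp
#include <cassert>
#include <cstddef>
#include <iterator>
#include <list>

// Sketch of an accessor modeled on default_accessor, with an iterator
// as the data handle. The element_type alias is intentionally omitted.
template <class Iterator>
struct iterator_accessor {
  using offset_policy = iterator_accessor;
  using data_handle_type = Iterator;
  using reference = typename std::iterator_traits<Iterator>::reference;
  using difference_type =
      typename std::iterator_traits<Iterator>::difference_type;

  // Advance a local copy of the handle, then dereference it.
  reference access(data_handle_type iter, std::size_t index) const {
    std::advance(iter, static_cast<difference_type>(index));
    return *iter;
  }

  // Return the handle moved forward by index elements.
  data_handle_type offset(data_handle_type iter, std::size_t index) const {
    std::advance(iter, static_cast<difference_type>(index));
    return iter;
  }
};

// Hypothetical helper: exercises the accessor with std::list, whose
// iterators are not random access.
inline void iterator_accessor_demo() {
  std::list<int> values{10, 20, 30, 40};
  iterator_accessor<std::list<int>::iterator> acc;
  assert(acc.access(values.begin(), 2) == 30);
  acc.access(values.begin(), 1) = 25;  // writes through the reference
  assert(*acc.offset(values.begin(), 1) == 25);
}
```

With non-random-access iterators, each call to `access` costs time linear in `index`, which may be a reason to restrict the accessor to random access iterators in practice.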
For defining the accessor's required type aliases (see [mdspan.accessor.reqmts]), it's a bit tricky to get the `element_type` (it's `const value_type` if the iterator is an iterator-of-const, else `value_type`). If you can figure out how to do that, say via a type alias `template<std::forward_iterator Iterator> using iterator_element_t = /* ... */`, then you'll find that useful below.
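One possible way to fill in `iterator_element_t` is to strip the reference from the iterator's `reference` type. This is a sketch, not the only option: it assumes that `reference` is a true language reference (`T&` or `const T&`), so it would give the wrong answer for proxy iterators such as `std::vector<bool>::iterator` or for iterators that return by value. It also uses a plain `class Iterator` parameter instead of the `std::forward_iterator` constraint so that it compiles as C++17.

```cpp
#include <iterator>
#include <list>
#include <type_traits>
#include <vector>

// const value_type for iterators-of-const, value_type otherwise.
// Assumption: iterator_traits<Iterator>::reference is T& or const T&.
template <class Iterator>
using iterator_element_t = std::remove_reference_t<
    typename std::iterator_traits<Iterator>::reference>;

// Mutable iterators yield a non-const element type.
static_assert(
    std::is_same_v<iterator_element_t<std::vector<float>::iterator>, float>);
static_assert(
    std::is_same_v<iterator_element_t<std::list<int>::iterator>, int>);

// Iterators-of-const yield a const element type.
static_assert(std::is_same_v<
              iterator_element_t<std::vector<float>::const_iterator>,
              const float>);
static_assert(std::is_same_v<iterator_element_t<const int*>, const int>);
```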
The question then becomes how to apply `__host__` or `__device__` (or both, for managed or pinned allocations) to `access`, so that you get the type safety benefits of RAFT's `host_accessor` and `device_accessor`. (For this case, you might also need to require random access iterators and replace `std::advance` with `iter[index]`, unless libcu++ has a `std::advance` equivalent blessed by `__host__ __device__`.) If you have different custom iterator types for host vs. device access, you can map from the trait to whether the `access` function has `__host__`, `__device__`, or both (e.g., via specialization of a base class).
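The base-class specialization idea could be sketched roughly as follows. Everything named here is hypothetical and invented for illustration: the `access_space` enum, the `iterator_access_space` trait, the `iterator_access_base` and `space_aware_accessor` templates, and the `SKETCH_HOST` / `SKETCH_DEVICE` macros. None of it is RAFT's actual API. The macros expand to nothing outside of nvcc, so the sketch also compiles as plain C++, where the three specializations behave identically.

```cpp
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

// Expand to CUDA execution-space annotations only when compiling with nvcc.
#if defined(__CUDACC__)
#define SKETCH_HOST __host__
#define SKETCH_DEVICE __device__
#else
#define SKETCH_HOST
#define SKETCH_DEVICE
#endif

// Hypothetical tag: where an iterator may be dereferenced.
enum class access_space { host, device, both };

// Hypothetical trait; specialize it for custom iterator types.
// The default assumes host-only access.
template <class Iterator>
struct iterator_access_space {
  static constexpr access_space value = access_space::host;
};

template <class Iterator>
using iter_ref_t = typename std::iterator_traits<Iterator>::reference;
template <class Iterator>
using iter_diff_t = typename std::iterator_traits<Iterator>::difference_type;

// Base class specialized per access space; only the annotations differ.
template <class Iterator, access_space Space>
struct iterator_access_base;

template <class Iterator>
struct iterator_access_base<Iterator, access_space::host> {
  SKETCH_HOST iter_ref_t<Iterator> access(Iterator iter,
                                          std::size_t index) const {
    return iter[static_cast<iter_diff_t<Iterator>>(index)];
  }
};

template <class Iterator>
struct iterator_access_base<Iterator, access_space::device> {
  SKETCH_DEVICE iter_ref_t<Iterator> access(Iterator iter,
                                            std::size_t index) const {
    return iter[static_cast<iter_diff_t<Iterator>>(index)];
  }
};

template <class Iterator>
struct iterator_access_base<Iterator, access_space::both> {
  SKETCH_HOST SKETCH_DEVICE iter_ref_t<Iterator> access(
      Iterator iter, std::size_t index) const {
    return iter[static_cast<iter_diff_t<Iterator>>(index)];
  }
};

// The accessor picks its base class from the trait.
template <class Iterator>
struct space_aware_accessor
    : iterator_access_base<Iterator, iterator_access_space<Iterator>::value> {
};

// Hypothetical helper: host-side check using the default (host) mapping.
inline void space_aware_accessor_demo() {
  std::vector<double> v{1.5, 2.5, 3.5};
  space_aware_accessor<std::vector<double>::iterator> acc;
  assert(acc.access(v.begin(), 2) == 3.5);
  acc.access(v.begin(), 0) = 0.5;
  assert(v[0] == 0.5);
}
```

The sketch uses `iter[index]` rather than `std::advance`, so it requires random access iterators. Under nvcc, calling the `device` specialization's `access` from host code should be rejected at compile time, which is the kind of type safety the accessors are meant to provide.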
At this point, you can convert your iterator ranges concisely into rank-1 mdspan.
    template<std::forward_iterator Iterator, std::integral IndexType, std::size_t Extent = dynamic_extent>
    using iterator_mdspan_t = mdspan<
        iterator_element_t<Iterator>, // (possibly const value_type)
        extents<IndexType, Extent>,
        iterator_accessor<Iterator>>;

    // For pre-C++20, remove Sentinel template parameter,
    // make the second function parameter an Iterator, and
    // replace std::ranges::distance with std::distance.
    template<class Iterator, class Sentinel>
    auto range_to_mdspan(Iterator begin, Sentinel end) {
        auto distance = std::ranges::distance(begin, end);
        using distance_type = decltype(distance);
        return iterator_mdspan_t<Iterator, distance_type, dynamic_extent>{begin, distance};
    }
Following #909, the API for `linalg::reduce_rows_by_keys` can take custom iterators, in order to save intermediate steps in `ann_kmeans_balanced`. But raw-pointer/iterator APIs are being deprecated in favor of mdspan.

We should provide appropriate helpers and types for iterator-based mdspan, and change this API accordingly. See discussions on the aforementioned PR.