NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

[BUG] coalesce with Tensor dispatch to coalesce with Shape #1612

Open cloudhan opened 3 days ago

cloudhan commented 3 days ago

Describe the bug

The following code does not compile with the latest commit (637b15906358191cb4238af419d408a65819d7ec):

#include "cute/tensor.hpp"
using namespace cute;

// alternative impl (different name) for coalesce(Tensor&&)
template <typename Tensor>
__host__ __device__ constexpr auto
coalesce_tensor(Tensor&& t) {
  return make_tensor(static_cast<Tensor&&>(t).data(), coalesce(t.layout()));
}

template <typename TensorEngine, typename TensorLayout>
__forceinline__ __host__ __device__ void
bad(const Tensor<TensorEngine, TensorLayout>& t) {
  auto c = coalesce(t);  // bad, dispatch to coalesce(Shape)
  // auto c = coalesce_tensor(t);  // good, there is no such coalesce_tensor(Shape)
}

__global__ void kernel() {
  auto t = make_tensor<float>(make_shape(_4{}, _1{}));
  auto c = coalesce(t);  // good, dispatch to coalesce(Tensor)
  bad(t);
}

int main() {
  kernel<<<1,1>>>();
  return 0;
}

Compiler output:

include/cute/layout.hpp(859): error: static assertion failed
    static_assert(is_integral<Shape>::value || is_tuple<Shape>::value);
    ^
          detected during:
            instantiation of "auto cute::coalesce(const Shape &) [with Shape=cute::Tensor<cute::ArrayEngine<float, 4>, cute::Layout<cute::tuple<cute::_4, cute::_1>, cute::tuple<cute::_1, cute::C<0>>>>]" at line 14 of tmp.cu
            instantiation of "void bad(const cute::Tensor<TensorEngine, TensorLayout> &) [with TensorEngine=cute::ArrayEngine<float, 4>, TensorLayout=cute::Layout<cute::tuple<cute::_4, cute::_1>, cute::tuple<cute::_1, cute::C<0>>>]" at line 21 of tmp.cu

Expected behavior

It compiles.

cloudhan commented 3 days ago

@thakkarV @ccecka Pls take a look.

ccecka commented 3 days ago

There's no pointer or tensor type.

Tensor a = make_tensor<float>(make_shape(_4{}, _1{}));          // Owning tensor
Tensor b = make_tensor(my_float_ptr, make_shape(_4{}, _1{}));   // Non-owning tensor

Your version should not compile, but could have a better error message.

cloudhan commented 3 days ago

@ccecka I updated the code sample in the issue. The problem is that coalesce(tensor) dispatches to coalesce(const Shape&) when the tensor is defined outside of the function bad and passed in by const reference.

In other words, the error message shows:

"auto cute::coalesce(const Shape &) [with Shape=cute::Tensor<cute::ArrayEngine<...

ccecka commented 3 days ago

Thanks, that's actually a pretty serious bug. It affects all functions that have an overload for Tensor&& and a generic overload (typically for IntTuple const&). These appear to include flatten, coalesce, filter_zeros, and take.
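
The dispatch behaviour can be reproduced without CuTe. Below is a minimal sketch in plain C++, assuming a hypothetical Widget type and hypothetical which() overloads that stand in for the tensor and for the two coalesce overloads; none of these names come from the library. For a const lvalue both overloads end up taking Widget const&, and partial ordering of function templates then prefers the T const& overload over the forwarding reference.

```cpp
#include <string>

// Hypothetical stand-in for a tensor type; not part of CuTe.
struct Widget {};

// Stand-in for the Tensor&& overload: a forwarding reference.
template <class T>
std::string which(T&&) { return "forwarding"; }

// Stand-in for the generic overload (IntTuple const& in the discussion above).
template <class T>
std::string which(T const&) { return "const-ref"; }

// Non-const lvalue: T&& deduces T = Widget&, so the parameter is Widget&.
// That binds without adding const and is the better match.
inline std::string call_with_mutable_lvalue() {
  Widget w;
  return which(w);
}

// Const lvalue: both overloads become which(Widget const&).
// The conversions tie, and partial ordering picks the T const& overload.
inline std::string call_with_const_lvalue() {
  const Widget w{};
  return which(w);
}
```

In this sketch call_with_mutable_lvalue() selects the forwarding overload while call_with_const_lvalue() selects the const-ref one, which matches the difference between calling coalesce inside kernel and inside bad.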

Previously, this has been solved by renaming functions; see group_modes vs group and filter vs filter_tuple. Rather than that, we should probably just do the three-overload reference dance (const&, &, and &&) for those functions:

template <class Engine, class Layout>
CUTE_HOST_DEVICE constexpr auto
flatten(Tensor<Engine,Layout> const& tensor) {
  return make_tensor(tensor.data(), flatten(tensor.layout()));
}

template <class Engine, class Layout>
CUTE_HOST_DEVICE constexpr auto
flatten(Tensor<Engine,Layout>& tensor) {
  return make_tensor(tensor.data(), flatten(tensor.layout()));
}

template <class Engine, class Layout>
CUTE_HOST_DEVICE constexpr auto
flatten(Tensor<Engine,Layout>&& tensor) {
  return make_tensor(tensor.data(), flatten(tensor.layout()));
}

I'll make sure this gets fixed as soon as possible.
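
Why the three explicit overloads help can be checked in isolation. The following is a sketch under stated assumptions: MockTensor, kind, and through_const_ref are hypothetical names used only for illustration, and MockTensor is an empty placeholder rather than cute::Tensor. Because MockTensor<Engine, Layout> const& is more specialized than T const&, the tensor overload now wins for const lvalues as well.

```cpp
#include <string>

// Hypothetical mock of a two-parameter tensor template; not cute::Tensor.
template <class Engine, class Layout>
struct MockTensor {};

// Generic overload, standing in for the IntTuple const& version.
template <class T>
std::string kind(T const&) { return "generic"; }

// One explicit overload per reference category, as proposed above.
template <class Engine, class Layout>
std::string kind(MockTensor<Engine, Layout> const&) { return "tensor const&"; }

template <class Engine, class Layout>
std::string kind(MockTensor<Engine, Layout>&) { return "tensor &"; }

template <class Engine, class Layout>
std::string kind(MockTensor<Engine, Layout>&&) { return "tensor &&"; }

// Mirrors bad() from the issue: the tensor arrives by const reference.
template <class Engine, class Layout>
std::string through_const_ref(MockTensor<Engine, Layout> const& t) {
  return kind(t);
}
```

With these overloads, a call through through_const_ref reaches the tensor const& version instead of the generic one, while non-tensor arguments such as a plain integer still fall through to the generic overload.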