cloudhan opened 3 days ago
@thakkarV @ccecka Pls take a look.
There's no pointer or tensor type; tensors are created with `make_tensor`:
```cpp
Tensor a = make_tensor<float>(make_shape(_4{}, _1{}));         // Owning tensor
Tensor b = make_tensor(my_float_ptr, make_shape(_4{}, _1{}));  // Non-owning tensor
```
Your version should not compile, but could have a better error message.
@ccecka I updated the issue code sample. The problem is that `coalesce(tensor)` dispatches to `coalesce(Shape const&)` when `tensor` is defined outside of the function `bad`, as the error message shows:

```
auto cute::coalesce(const Shape &) [with Shape=cute::Tensor<cute::ArrayEngine<...
```
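To make the dispatch concrete, here is a minimal, self-contained sketch of the overload set in question. The `Tensor`, `coalesce`, `GenericOverload`, and `TensorOverload` names are hypothetical stand-ins, not the real CuTe declarations. The sketch assumes only that the tensor overload takes `Tensor<Engine,Layout>&&` while the generic one takes `Shape const&`. An rvalue reference cannot bind to a named (lvalue) tensor, so the generic overload becomes the only viable candidate. That matches the `[with Shape=cute::Tensor<...` in the error message. All checks are `static_assert`s, so compiling the file is the test.

```cpp
#include <type_traits>
#include <utility>

// Hypothetical stand-ins for illustration only; these are NOT the real
// cute::Tensor / cute::coalesce declarations.
template <class Engine, class Layout>
struct Tensor {};

struct GenericOverload {};  // tag: the generic (Shape/IntTuple) overload was chosen
struct TensorOverload {};   // tag: the Tensor overload was chosen

// Generic overload, standing in for coalesce(Shape const&).
template <class Shape>
GenericOverload coalesce(Shape const&) { return {}; }

// Tensor overload that only accepts rvalues, standing in for coalesce(Tensor&&).
template <class Engine, class Layout>
TensorOverload coalesce(Tensor<Engine, Layout>&&) { return {}; }

using T = Tensor<int, int>;

// A temporary binds to Tensor&&, which is preferred over Shape const&.
static_assert(std::is_same_v<decltype(coalesce(T{})), TensorOverload>);

// A named (lvalue) tensor cannot bind to Tensor&&, so the generic overload
// is the only viable candidate and is silently selected.
static_assert(std::is_same_v<decltype(coalesce(std::declval<T&>())), GenericOverload>);
static_assert(std::is_same_v<decltype(coalesce(std::declval<T const&>())), GenericOverload>);
```

No diagnostic is emitted when the generic overload is picked. In the real library, the failure only surfaces later, inside the generic function's body.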
Thanks, that's actually a pretty serious bug, and it affects all functions that have both an overload for `Tensor&&` and a generic overload (typically for `IntTuple const&`). These appear to include:

- `flatten`
- `coalesce`
- `filter_zeros`
- `take`
Previously, this was solved by renaming functions (see `group_modes` vs. `group` and `filter` vs. `filter_tuple`). Rather than that, we should probably just do the rref three-step dance for those functions, providing `const&`, `&`, and `&&` overloads:
```cpp
// Binds const lvalue tensors.
template <class Engine, class Layout>
CUTE_HOST_DEVICE constexpr auto
flatten(Tensor<Engine,Layout> const& tensor) {
  return make_tensor(tensor.data(), flatten(tensor.layout()));
}

// Binds non-const lvalue tensors (preferred over const& for them).
template <class Engine, class Layout>
CUTE_HOST_DEVICE constexpr auto
flatten(Tensor<Engine,Layout>& tensor) {
  return make_tensor(tensor.data(), flatten(tensor.layout()));
}

// Binds rvalue tensors, e.g. temporaries.
template <class Engine, class Layout>
CUTE_HOST_DEVICE constexpr auto
flatten(Tensor<Engine,Layout>&& tensor) {
  return make_tensor(tensor.data(), flatten(tensor.layout()));
}
```
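The following self-contained mock checks at compile time that the three overloads above route every value category to a tensor overload, while non-tensor arguments still reach the generic one. `Tensor`, `make_tensor`, `flatten`, and `GenericResult` here are hypothetical, heavily simplified stand-ins, not CuTe's real types. Layout flattening is elided, and in this mock `data()` on a const tensor returns a const pointer.

```cpp
#include <type_traits>
#include <utility>

// Hypothetical, simplified stand-in for cute::Tensor (NOT the real type):
// a pointer plus a layout; data() yields a const pointer on a const tensor.
template <class Element, class Layout>
struct Tensor {
  Element* ptr;
  Layout   layout_;
  Element*       data()         { return ptr; }
  Element const* data() const   { return ptr; }
  Layout         layout() const { return layout_; }
};

// Non-owning view over an existing pointer (Element may be const-qualified).
template <class Element, class Layout>
Tensor<Element, Layout> make_tensor(Element* p, Layout l) { return {p, l}; }

// Tag returned when the generic (IntTuple-style) overload is picked.
struct GenericResult {};

template <class Shape>
GenericResult flatten(Shape const&) { return {}; }

// The three-step rref dance. Only overload selection and the constness of
// the resulting view matter here, so the layout is passed through unchanged.
template <class Element, class Layout>
auto flatten(Tensor<Element, Layout> const& t) {  // const lvalues
  return make_tensor(t.data(), t.layout());
}
template <class Element, class Layout>
auto flatten(Tensor<Element, Layout>& t) {        // non-const lvalues
  return make_tensor(t.data(), t.layout());
}
template <class Element, class Layout>
auto flatten(Tensor<Element, Layout>&& t) {       // rvalues (temporaries)
  return make_tensor(t.data(), t.layout());
}

using T = Tensor<float, int>;

// Every value category now reaches a Tensor overload.
static_assert(std::is_same_v<decltype(flatten(std::declval<T&>())),       Tensor<float, int>>);
static_assert(std::is_same_v<decltype(flatten(std::declval<T const&>())), Tensor<float const, int>>);
static_assert(std::is_same_v<decltype(flatten(std::declval<T>())),        Tensor<float, int>>);

// Non-tensor arguments still go to the generic overload.
static_assert(std::is_same_v<decltype(flatten(42)), GenericResult>);
```

In this mock, the non-const `&` overload is what keeps `flatten` of a mutable lvalue returning a mutable view. With only the `const&` overload, lvalues would still avoid the generic overload but would get a read-only view.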
I'll make sure this gets fixed as soon as possible.
Describe the bug
The following code does not compile with the latest commit (637b15906358191cb4238af419d408a65819d7ec).
Expected behavior
It compiles.