wjakob / nanobind

nanobind: tiny and efficient C++/Python bindings
BSD 3-Clause "New" or "Revised" License

[BUG]: The stride of an ndarray can be wrong for numpy views #487

Closed: hpkfft closed this issue 3 months ago

hpkfft commented 3 months ago

Problem description

NumPy allows creating a new view of an array's existing data, which can reinterpret the bytes in memory under a different dtype. In particular, an array with 2N real values per row can be reinterpreted as an array with N complex values per row. [This is useful for FFTs of real data: the result is complex-valued, and since the Fourier transform is invertible, the output contains no more "data content" than the input. So NumPy's views do the right thing.]

The issue arises when the stride from one row to the next is an odd number of real-valued elements. NumPy measures strides in bytes, so the view leaves the strides unchanged and everything just works. nanobind, however, measures strides in units of the itemsize, so an nb::ndarray cannot correctly represent a stride that is not an integral multiple of the dtype's size.
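The mismatch is easy to see with plain NumPy, before nanobind is even involved (a minimal sketch of the situation described above):

```python
import numpy as np

# A 2x7 float32 array: each row is 28 bytes (7 elements * 4 bytes each).
a = np.zeros((2, 7), dtype=np.float32)

# Drop the last column, then reinterpret each row's 6 floats as
# 3 complex64 values. The data pointer and byte strides are unchanged.
v = a[:, 0:6].view(np.complex64)

print(v.strides)    # (28, 8) -- byte strides, as NumPy reports them
print(v.itemsize)   # 8 bytes per complex64 element

# The row stride in itemsize units would be 28 / 8 = 3.5, which has no
# integer representation in a stride measured in elements.
print(v.strides[0] % v.itemsize)   # 4 -> not a multiple of the itemsize
```

Any binding layer that stores the row stride as an integer element count must either reject this array or silently truncate 3.5 to 3, which is exactly the wrong value reported by nb::ndarray below.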

Reproducible example code

C++

    m.def("inspect", [](const nb::ndarray<nb::ro>& a) {
            printf("Array on nb::device::%s%s (id %u) at address %p\n",
                   (a.device_type() == nb::device::cpu::value)  ? "cpu"  : "",
                   (a.device_type() == nb::device::cuda::value) ? "cuda" : "",
                   a.device_id(), a.data());
            printf("    dtype : %s%s%s%s   itemsize : %2zu\n",
                   (a.dtype() == nb::dtype<float>())    ? "float32" : "",
                   (a.dtype() == nb::dtype<double>())   ? "float64" : "",
                   (a.dtype() == nb::dtype<std::complex<float>>())
                                             ? "complex<float32>" : "",
                   (a.dtype() == nb::dtype<std::complex<double>>())
                                             ? "complex<float64>" : "",
                   a.itemsize());
            printf("    ndim : %zu\n", a.ndim());
            for (size_t i = 0; i < a.ndim(); ++i) {
                printf("    shape(%zu) : %2zu    stride(%zu) : %2zd\n",
                              i,  a.shape(i),       i,  a.stride(i));
            }
        });

Python (interactive session):

>>> a = np.array([[1, 2, 3, 4, 5, 6, np.NAN],
...               [8, 0, 0, 0, 0, 0, np.NAN]], dtype=np.float32)
>>> a
array([[ 1.,  2.,  3.,  4.,  5.,  6., nan],
       [ 8.,  0.,  0.,  0.,  0.,  0., nan]], dtype=float32)
>>> m.inspect(a)
Array on nb::device::cpu (id 0) at address 0x1a5ddb0
    dtype : float32   itemsize :  4
    ndim : 2
    shape(0) :  2    stride(0) :  7
    shape(1) :  7    stride(1) :  1
>>> a.strides
(28, 4)
>>> 
>>> s = a[:, 0:6]  # slice
>>> s
array([[1., 2., 3., 4., 5., 6.],
       [8., 0., 0., 0., 0., 0.]], dtype=float32)
>>> m.inspect(s)
Array on nb::device::cpu (id 0) at address 0x1a5ddb0
    dtype : float32   itemsize :  4
    ndim : 2
    shape(0) :  2    stride(0) :  7
    shape(1) :  6    stride(1) :  1
>>> s.strides
(28, 4)
>>> 
>>> v = s.view(np.complex64)
>>> v
array([[1.+2.j, 3.+4.j, 5.+6.j],
       [8.+0.j, 0.+0.j, 0.+0.j]], dtype=complex64)
>>> m.inspect(v)
Array on nb::device::cpu (id 0) at address 0x1a5ddb0
    dtype : complex<float32>   itemsize :  8
    ndim : 2
    shape(0) :  2    stride(0) :  3
    shape(1) :  3    stride(1) :  1
>>> v.strides
(28, 8)
hpkfft commented 3 months ago

I'd like to suggest that nb::ndarray store strides with finer granularity. Two ideas immediately come to mind: measure strides in bytes, as NumPy and the buffer protocol do, or measure them in units of the dtype's alignment rather than its size.

There may be some plausible justification for the latter option if staying closer to the current behavior is desired:

    static_assert(sizeof(uint32_t) == 4);
    static_assert(alignof(uint32_t) == 4);

    static_assert(sizeof(float) == 4);
    static_assert(alignof(float) == 4);

    static_assert(sizeof(std::complex<float>) == 8);
    static_assert(alignof(std::complex<float>) == 4);

wjakob commented 3 months ago

AFAIK this is not possible because nanobind centers around DLPack. The ability for nb::ndarray to talk to NumPy via the buffer protocol is really just there to support older NumPy versions, since DLPack is still a relatively new feature.

In any case, for DLPack, all of these quantities are relative to the itemsize. So I think this request is not compatible with the design of the library.

hpkfft commented 3 months ago

Yes, I see.

>>> import numpy as np
>>> a = np.array([[1, 2, 3, 4, 5, 6, np.NAN],
...               [8, 0, 0, 0, 0, 0, np.NAN]], dtype=np.float32)
>>> s = a[:, 0:6]  # slice
>>> v = s.view(np.complex64)
>>> v.__dlpack__()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
BufferError: DLPack only supports strides which are a multiple of itemsize.

Would it be good for nanobind to throw an exception in the buffer protocol path if strides are not a multiple of itemsize? This would avoid silent data corruption if the stride (in bytes) cannot be correctly converted to itemsize units.
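The guard could look roughly like this (a Python sketch of the proposed logic only; `check_strides` is a hypothetical name, and the real fix would live in nanobind's C++ buffer-protocol import path):

```python
import numpy as np

def check_strides(arr):
    """Reject arrays whose byte strides cannot be expressed in itemsize units."""
    for s in arr.strides:
        if s % arr.itemsize != 0:
            raise BufferError(
                "stride of %d bytes is not a multiple of the itemsize "
                "(%d bytes)" % (s, arr.itemsize))
    # Safe to convert to itemsize units, as DLPack requires.
    return tuple(s // arr.itemsize for s in arr.strides)

a = np.zeros((2, 7), dtype=np.float32)
print(check_strides(a))         # (7, 1) -- fine, strides divide evenly

v = a[:, 0:6].view(np.complex64)
try:
    check_strides(v)            # row stride 28 is not a multiple of 8
except BufferError as e:
    print("rejected:", e)
```

This mirrors what NumPy's own `__dlpack__` does above: fail loudly at the conversion boundary rather than hand over a truncated stride.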

If you like, I can work on a PR for this. I'm a hardware/assembly/C++ guy who is new to python, so your careful code review and any suggestions you have would be welcome.

wjakob commented 3 months ago

Would it be good for nanobind to throw an exception in the buffer protocol path if strides are not a multiple of itemsize? This would avoid silent data corruption if the stride (in bytes) cannot be correctly converted to itemsize units.

Absolutely—if this currently leads to corruption, then it should be fixed. I'm happy to review a PR if you make one.

wjakob commented 3 months ago

Closed via #489