cupy / cupy

NumPy & SciPy for GPU
https://cupy.dev
MIT License

split won't accept 0d array as the section count #3512

Open peterbell10 opened 4 years ago

peterbell10 commented 4 years ago

In NumPy, np.split accepts a 0-dimensional array as the number of sections to split into:

In [1]: import numpy as np
   ...: np.split(np.arange(4).reshape(2,2), np.array(2), 0)                                 
Out[1]: [array([[0, 1]]), array([[2, 3]])]

In CuPy, the same expression fails:

In [1]: import cupy
   ...: cupy.split(cupy.arange(4).reshape(2,2), cupy.array(2), 0)                           
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-17-13106a5bb5e4> in <module>
----> 1 cupy.split(cupy.arange(4).reshape(2,2), cupy.array(2), 0)

~/miniconda3/envs/uarray-dev/lib/python3.7/site-packages/cupy/manipulation/split.py in split(ary, indices_or_sections, axis)
     76                 'If you want to split the array into non-equally-sized '
     77                 'arrays, use array_split instead.')
---> 78     return array_split(ary, indices_or_sections, axis)
     79 
     80 

~/miniconda3/envs/uarray-dev/lib/python3.7/site-packages/cupy/manipulation/split.py in array_split(ary, indices_or_sections, axis)
     14 
     15     """
---> 16     return core.array_split(ary, indices_or_sections, axis)
     17 
     18 

cupy/core/_routines_manipulation.pyx in cupy.core._routines_manipulation.array_split()

cupy/core/_routines_manipulation.pyx in cupy.core._routines_manipulation.array_split()

cupy/core/core.pyx in cupy.core.core.ndarray.__iter__()

TypeError: iteration over a 0-d array

kmaehashi commented 4 years ago

This is currently by design. indices_or_sections in cupy.split must reside in host memory, because its value is needed on the host side.

https://docs-cupy.chainer.org/en/stable/reference/generated/cupy.split.html

"Note that the sequence on the device memory is not allowed."

I think we need to improve the documentation to state that 0-dim arrays in device memory are not allowed either.
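
Until that changes, one possible workaround is to copy the section count to the host before calling cupy.split. The sketch below is only an illustration: to_host_sections and split_with_device_sections are hypothetical helper names, not part of CuPy, and the sketch assumes the argument is either a plain Python value or an array exposing tolist() (as numpy.ndarray and cupy.ndarray do).

```python
def to_host_sections(indices_or_sections):
    """Return indices_or_sections as a host-side Python int or list.

    Hypothetical helper, not part of CuPy.
    """
    # numpy.ndarray and cupy.ndarray both provide tolist(); for a CuPy
    # array this copies the data to the host first. A 0-d array becomes
    # a Python scalar, a 1-d array becomes a list.
    if hasattr(indices_or_sections, "tolist"):
        return indices_or_sections.tolist()
    # Plain ints and Python sequences are already on the host.
    return indices_or_sections


def split_with_device_sections(ary, indices_or_sections, axis=0):
    """Call cupy.split after moving the section spec to the host."""
    # Imported lazily so the helper above has no GPU dependency.
    import cupy
    return cupy.split(ary, to_host_sections(indices_or_sections), axis)
```

With this, the failing call from the report could be written as split_with_device_sections(cupy.arange(4).reshape(2,2), cupy.array(2), 0). Note that copying the value to the host forces a device-to-host transfer, which is the cost the current design makes explicit.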