thouis / numpy-trac-migration

numpy Trac to github issues migration

array_split does not preserve dtype of original array when subarray length is 0. (Trac #2156) #5948

Open numpy-gitbot opened 11 years ago

numpy-gitbot commented 11 years ago

Original ticket http://projects.scipy.org/numpy/ticket/2156 on 2012-06-12 by atmention:fengy-research, assigned to unknown.

The original dtype is u4 (uint32), but the empty subarrays in the output become f8 (float64).

In [4]: numpy.array_split ( zeros(shape=4, dtype='u4'), 10)
Out[4]: 
[array([0], dtype=uint32),
 array([0], dtype=uint32),
 array([0], dtype=uint32),
 array([0], dtype=uint32),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64)]
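A small check for this behavior is sketched below, assuming NumPy is installed. The helper name split_preserves_dtype is hypothetical. It tests whether every piece returned by numpy.array_split has the input's dtype. With the behavior reported here, the six empty pieces come back as float64 and the check fails. On a build where empty pieces keep the input dtype, the check passes.

```python
import numpy as np

def split_preserves_dtype(ary, sections):
    """Hypothetical helper: True if every piece from np.array_split has ary's dtype."""
    return all(part.dtype == ary.dtype for part in np.array_split(ary, sections))

# Same input as the transcript: 4 elements split into 10 sections,
# so the first 4 pieces hold one element each and the last 6 are empty.
a = np.zeros(shape=4, dtype='u4')
print(split_preserves_dtype(a, 10))
```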

The following helper, which array_split applies to its list of subarrays, seems to be causing the problem:

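def _replace_zero_by_x_arrays(sub_arys):
    # _nx is NumPy's internal alias for numpy.core.numeric
    for i in range(len(sub_arys)):
        if len(_nx.shape(sub_arys[i])) == 0:
            # 0-d subarray: replaced by an empty array of the default dtype (float64)
            sub_arys[i] = _nx.array([])
        elif _nx.sometrue(_nx.equal(_nx.shape(sub_arys[i]),0)):
            # some dimension has length 0: same replacement, so the dtype is lost
            sub_arys[i] = _nx.array([])
    return sub_arys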

Shouldn't those _nx.array([]) calls be replaced by _nx.array([], dtype=sub_arys[i].dtype)?
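Here is a minimal sketch of the suggested change, assuming NumPy is installed. It is not the patch attached to this ticket. The function name replace_zero_by_x_arrays (without the leading underscore) is hypothetical. The sketch makes a few substitutions:

- The public numpy namespace stands in for the internal _nx alias.
- np.any stands in for sometrue.
- The two branches are merged into one.
- np.asarray(...).dtype replaces sub_arys[i].dtype, so a plain Python scalar would not raise.

```python
import numpy as np

def replace_zero_by_x_arrays(sub_arys):
    """Hypothetical dtype-preserving variant of the helper quoted above.

    Replaces 0-d or zero-length subarrays with empty 1-D arrays,
    but keeps each subarray's original dtype instead of float64.
    """
    for i in range(len(sub_arys)):
        shape = np.shape(sub_arys[i])
        if len(shape) == 0 or np.any(np.equal(shape, 0)):
            # Pass the original dtype through to the empty replacement
            sub_arys[i] = np.array([], dtype=np.asarray(sub_arys[i]).dtype)
    return sub_arys
```

For example, applying it to [np.zeros(1, dtype='u4'), np.zeros(0, dtype='u4')] leaves both entries as uint32, while the second entry is still replaced by an empty 1-D array.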

numpy-gitbot commented 11 years ago

Attachment added by atmention:fengy-research on 2012-06-12: patch