pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

Specialization of `set_dims` for new dimensions of length 1 #9462

Open hmaarrfk opened 2 weeks ago

hmaarrfk commented 2 weeks ago

Is your feature request related to a problem?

Our duck arrays get caught by the call to duck_array_ops.broadcast_to here: https://github.com/pydata/xarray/blob/0af197985840a715c3566b6bdb5f355b21224e92/xarray/core/variable.py#L1382

This really boils down to np.broadcast_to, which we don't implement yet.

In the lines above, there seems to be a desire to keep the array writeable (that won't work for our arrays, which are read-only anyway). But there is another case in which the underlying array can stay writeable: when the only new dimensions being added have size 1.
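The difference can be reproduced with plain NumPy, without xarray or any duck array involved. This is a sketch for illustration only; `data` is a stand-in for a variable's backing array.

```python
import numpy as np

# A plain, writeable NumPy array standing in for the variable's data.
data = np.zeros((3, 2))

# np.broadcast_to returns a read-only view, even when the target shape
# only prepends a dimension of size 1.
via_broadcast = np.broadcast_to(data, (1, 3, 2))

# Indexing with None (np.newaxis) returns an ordinary view that keeps
# the writeable flag of the original array.
via_indexing = data[None, ...]

print(via_broadcast.shape, via_broadcast.flags.writeable)  # (1, 3, 2) False
print(via_indexing.shape, via_indexing.flags.writeable)    # (1, 3, 2) True

# Writing through the indexed view modifies the original array.
via_indexing[0, 0, 0] = 5.0
print(data[0, 0])  # 5.0
```

Both results have the same shape; only the indexing path preserves writeability, which is why it is preferable when no real broadcasting is needed.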

Describe the solution you'd like

New dimensions of size 1 should be inserted by indexing with np.newaxis (None) instead of calling broadcast_to.

I think the code could be refactored to:

    # Fragment of the set_dims body; dim and shape are its arguments
    self_dims = set(self.dims)
    new_dims = tuple(d for d in dim if d not in self_dims)
    expanded_dims = new_dims + self.dims
    # shape (if given) holds one size per entry in dim, so map names to sizes
    dims_map = dict(zip(dim, shape)) if shape is not None else {}

    if self.dims == expanded_dims:
        # don't use broadcast_to unless necessary so the result remains
        # writeable if possible
        expanded_data = self.data
    elif shape is None or all(dims_map[d] == 1 for d in new_dims):
        # don't use broadcast_to unless necessary so the result remains
        # writeable if possible, especially when only adding
        # dimensions of size 1
        indexer = (None,) * (len(expanded_dims) - self.ndim) + (...,)
        expanded_data = self.data[indexer]
    else:
        tmp_shape = tuple(dims_map[d] for d in expanded_dims)
        expanded_data = duck_array_ops.broadcast_to(self.data, tmp_shape)

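This effectively just swaps the "else" and "elif" branches of the current code, while also sending the case where every new dimension has size 1 down the indexing branch.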
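A minimal, self-contained sketch of the proposed branch logic, assuming plain NumPy arrays. expand_data is a hypothetical helper, not an xarray function, and np.broadcast_to stands in for duck_array_ops.broadcast_to. It only returns the expanded dimension names and data, leaving out whatever set_dims does with them afterwards.

```python
import numpy as np


def expand_data(data, self_dims, dim, shape=None):
    """Hypothetical stand-alone version of the proposed set_dims branches.

    data      -- the array backing the variable
    self_dims -- tuple of the variable's current dimension names
    dim       -- sequence of target dimension names
    shape     -- optional sequence of sizes, one per entry in dim
    """
    self_dims = tuple(self_dims)
    new_dims = tuple(d for d in dim if d not in self_dims)
    expanded_dims = new_dims + self_dims
    dims_map = dict(zip(dim, shape)) if shape is not None else {}

    if self_dims == expanded_dims:
        # No new dimensions: return the data untouched.
        expanded_data = data
    elif shape is None or all(dims_map[d] == 1 for d in new_dims):
        # Only size-1 dimensions are added: prepend them with None,
        # which keeps a writeable array writeable.
        indexer = (None,) * len(new_dims) + (...,)
        expanded_data = data[indexer]
    else:
        # At least one new dimension is larger than 1, so a real
        # broadcast is needed; NumPy returns it as a read-only view.
        tmp_shape = tuple(dims_map[d] for d in expanded_dims)
        expanded_data = np.broadcast_to(data, tmp_shape)
    return expanded_dims, expanded_data


data = np.zeros((3, 2))

dims, out = expand_data(data, ("x", "y"), ("t", "x", "y"), shape=(1, 3, 2))
print(dims, out.shape, out.flags.writeable)  # ('t', 'x', 'y') (1, 3, 2) True

dims, out = expand_data(data, ("x", "y"), ("t", "x", "y"), shape=(4, 3, 2))
print(dims, out.shape, out.flags.writeable)  # ('t', 'x', 'y') (4, 3, 2) False
```

Only the second call, which genuinely needs a length-4 dimension, goes through broadcast_to; the size-1 case stays writeable.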

Describe alternatives you've considered

Manually broadcasting the DataArrays prior to the call to xr.expand_dims

Additional context

    import numpy as np

    a = np.zeros((3, 2))
    view = a[np.newaxis, :, np.newaxis, :, np.newaxis, ...]
    assert view.flags.writeable

dcherian commented 4 days ago

> In the lines above, there seems to be a desire to keep the array "writable" (that won't work in our arrays, they are read-only anyway)

We want this because we call broadcast under the hood in many places, and it is quite surprising when arrays suddenly become read-only.

> indexer = (None,) * (len(expanded_dims) - self.ndim) + (...,)

This looks like it is supported by the array API indexing rules, so your proposal sounds good to me.

> Each None in the selection tuple must expand the dimensions of the resulting selection by one dimension of size 1. The position of the added dimension must be the same as the position of None in the selection tuple.
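NumPy follows the indexing rule quoted above, so it can be used to check what the proposed indexer does. This is only a NumPy check; whether a particular duck array honours the same rule is an assumption to verify separately.

```python
import numpy as np

x = np.arange(6).reshape(3, 2)

# Each None adds one axis of size 1 at the position it occupies
# in the selection tuple.
print(x[None, ...].shape)  # (1, 3, 2)
print(x[:, None, :].shape)  # (3, 1, 2)
print(x[..., None].shape)  # (3, 2, 1)

# The proposed indexer: k leading Nones followed by an ellipsis prepends
# k size-1 axes and leaves the existing axes untouched.
k = 2
indexer = (None,) * k + (...,)
y = x[indexer]
print(y.shape)  # (1, 1, 3, 2)
print(np.shares_memory(x, y))  # True: y is a view, not a copy
```

Because the Nones all come first, the new axes line up with expanded_dims = new_dims + self.dims in the proposed code.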

hmaarrfk commented 4 days ago

OK, thanks. I'll try to think of a test to add and open a PR for this. We can ultimately work around issues like this, but I appreciate the thought you put into answering all our questions.