maxrjones opened 2 years ago
I'm looking at this again, and I think the way to understand my PoV is that I want the least friction between the `for batch in bgen` line and where I call my NN model (e.g., `model.fit(...)`). In my situation (case 1, I think), the lack of a sample dimension forces you to create a new sample dimension, and I would argue that this is non-trivial, for a few reasons:
Additionally, as per @weiji14's comments a while back, it's a much lower cognitive load to return an array with an extra sample dimension and use `np.squeeze` than it is to recreate a dimension inside the training loop.
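A minimal sketch of the two workflows being compared (pure NumPy/xarray; the shapes and dimension names are made up for illustration):

```python
import numpy as np
import xarray as xr

# A patch as the generator might return it *with* a sample dimension:
patch = xr.DataArray(np.zeros((1, 16, 16)), dims=("sample", "x", "y"))

# Dropping an unwanted length-1 sample dimension is a single call:
image = np.squeeze(patch.values, axis=0)  # shape (16, 16)

# Recreating a missing sample dimension inside the training loop takes
# more ceremony and xarray-specific knowledge:
patch_2d = xr.DataArray(np.zeros((16, 16)), dims=("x", "y"))
with_sample = patch_2d.expand_dims("sample").transpose("sample", "x", "y")
```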
As it so happens, there is a bug with `xr.Dataset.expand_dims`. You can see the bug report here: https://github.com/pydata/xarray/issues/7456. It seems the xarray documentation is at least partially at fault, according to the maintainers.
Apparently, there is no combination of `xr.Dataset.expand_dims` and `xr.Dataset.transpose` that will put the sample dimension of the Dataset in the first position. You can use `axis=0` to do this for the Dataset's internal arrays, but that transposition does not affect the Dataset itself. According to https://github.com/pydata/xarray/issues/7456, this is the desired behavior (believe it or not). Supposedly, the position of a Dataset's dimensions is arbitrary, but Keras disagrees :(
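A small reproduction of the behavior being described (variable and dimension names are invented): `axis=0` does place the new dimension first in each variable's underlying array, even though the Dataset itself attaches no meaning to dimension order.

```python
import numpy as np
import xarray as xr

ds = xr.Dataset({"t2m": (("x", "y"), np.zeros((3, 4)))})

# axis=0 inserts the new dimension at position 0 of each *variable*:
expanded = ds.expand_dims(dim="sample", axis=0)
print(expanded["t2m"].dims)  # ('sample', 'x', 'y')

# The Dataset itself has no meaningful dimension order, so frameworks
# that care about axis order (e.g. Keras) must look at the underlying
# per-variable arrays rather than the Dataset's dimension mapping.
```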
Thanks for your thoughts! I commented on the other thread about the expand_dims/transpose behavior, but more generally you're correct that xbatcher will need to be responsible for the correct ordering of dimensions, as xarray's data model is generally agnostic to axis order.
TBH I actually blame Keras/PyTorch for caring about axis order. So passé!
Thinking about this some more, the current behavior does make sense if we're not considering an ML context. Like, if you wanted to sample a bunch of patches and average each of them, a sample dimension wouldn't make sense.
I'm thinking that we could have BatchGenerator wrappers for the ML libs, and then we can append a sample dimension there. I had a look at the existing ones, but I think they don't have this.
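Something like the following could sit between the batch generator and the ML library. This is only a sketch: the wrapper name is invented, and a plain generator of DataArrays stands in for a real `BatchGenerator`.

```python
import numpy as np
import xarray as xr

def with_sample_dim(batch_gen, dim="sample"):
    """Hypothetical wrapper: yield each batch with a leading sample
    dimension, adding one if the generator did not produce it."""
    for batch in batch_gen:
        if dim not in batch.dims:
            batch = batch.expand_dims(dim)  # new dim inserted at axis 0
        yield batch.transpose(dim, ...)     # ensure sample comes first

# Stand-in for a BatchGenerator: two 2-D patches without a sample dim.
fake_gen = (xr.DataArray(np.zeros((8, 8)), dims=("x", "y")) for _ in range(2))
batches = list(with_sample_dim(fake_gen))
```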
What is your issue?
As shown in the section below, there are a couple of cases in which the batch generator will not include a `sample` dimension in the returned dataset/dataarray:

1. `input_dims` does not exceed the number of dimensions in the original dataset by more than one. In this case, the original dimension names will be retained as long as `concat_input_dims=False`.
2. The same `input_dims`, but with `concat_input_dims=True`. In this case, an extra dimension called `input_batch` is created along with the original dimensions appended with `_input`.

https://github.com/xarray-contrib/xbatcher/blob/59df776734c9967a1cf38e309ab1b509152def08/xbatcher/testing.py#L126-L142
Would it be better to always include a `sample` dimension in the output batches, for example by renaming the excluded dimension to `sample` for case 1?
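For case 1, that rename could look roughly like this. This is a sketch of the idea, not xbatcher's actual code; the dimension names and sizes are invented.

```python
import numpy as np
import xarray as xr

# Case 1 shape: input_dims covers ("x", "y"), so the leftover dimension
# from the original dataset ("time" here) is the one that would be
# renamed to "sample".
batch = xr.DataArray(np.zeros((10, 16, 16)), dims=("time", "x", "y"))

input_dims = {"x": 16, "y": 16}
(excluded,) = set(batch.dims) - set(input_dims)  # the one non-input dim
batch = batch.rename({excluded: "sample"}).transpose("sample", ...)
```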