@albertz it would be great if you could have a look at this.
One problem seems to be the implementation of `MergeDims` (here): in case the batch axis is involved, this uses the RETURNN `FlattenBatchLayer`. This is wrong, or at least inconsistent with what PyTorch does. Why do we even need this separate logic? The RETURNN `MergeDimsLayer` should also handle the case with the batch axis, I think. If not, we should fix that. But using `FlattenBatchLayer` does sth different.
The test case is problematic. You should not use `n_batch`, `n_time` inside the function `model_func`. You should redefine them there like: `n_batch, n_time, n_in = y.shape`. This is important so that we keep the batch dim properly.
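For illustration, a minimal sketch of the suggested pattern (the real test's `model_func` signature and body may differ):

```python
import torch

# Hypothetical minimal model_func: the sizes are re-derived from the input
# tensor inside the function, so the batch dim stays a proper dynamic dim for
# the converter instead of a baked-in Python int from the outer scope.
def model_func(y: torch.Tensor) -> torch.Tensor:
    n_batch, n_time, n_in = y.shape  # re-derive here; do not close over outer n_batch/n_time
    return y.reshape(n_batch, n_time * n_in)

out = model_func(torch.randn(3, 10, 5))
assert out.shape == (3, 50)
```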
> The test case is problematic. You should not use `n_batch`, `n_time` inside the function `model_func`. You should redefine them there like: `n_batch, n_time, n_in = y.shape`. This is important so that we keep the batch dim properly.
Yes, done.
I wonder if the `reshape(n_b, n_t * n_index)` will be handled correctly, i.e. that it gets back the batch dim. This would be via the `Unflatten` module. Maybe we need to extend the logic here.
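For illustration only (sizes and names are made up, not the actual test):

```python
import torch

# After working on the merged B*T axis, this reshape has to recover the batch
# dim; on the converter side that maps to the Unflatten module (RETURNN
# SplitDimsLayer), so the batch dim must come back as a dynamic dim, not as the
# literal size seen here at tracing time.
n_b, n_t, n_index = 3, 10, 7
gathered = torch.randn(n_b * n_t, n_index)
out = gathered.reshape(n_b, n_t * n_index)
assert out.shape == (n_b, n_t * n_index)
```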
Why are the tests not running now?
Btw, you should rebase to resolve the conflicts.
> One problem seems to be the implementation of `MergeDims` (here): [...] This is wrong. Or at least inconsistent with what PyTorch does. Why do we even need this separate logic? The RETURNN `MergeDimsLayer` should also handle the case with the batch axis, I think. If not, we should fix that. But using `FlattenBatchLayer` does sth different.
Do you mean just removing that special case from `MergeDims`, like what I just committed?
> > One problem seems to be the implementation of `MergeDims` (here): [...] using `FlattenBatchLayer` [...]
>
> Do you mean just removing that special case from `MergeDims`, like what I just committed?
Yes. Not sure if that breaks sth else. Or why we even did that in the first place.
But this is required here. Especially, the indices (via `arange(n_b * n_t)`) index into the merged B*T axis. This must be merged and not flattened.
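To illustrate the difference (made-up sizes):

```python
import torch

# With MergeDims semantics, position b * n_t + t in the merged B*T axis
# addresses batch entry b at time step t, which is exactly what
# arange(n_b * n_t) assumes.
n_b, n_t, n_in = 3, 10, 5
x = torch.randn(n_b, n_t, n_in)
merged = x.reshape(n_b * n_t, n_in)   # pure reshape: padding is kept
idx = torch.arange(n_b * n_t)         # flat positions b * n_t + t
assert torch.equal(merged[idx], merged)
# FlattenBatchLayer instead packs the sequences (padded frames are dropped per
# sequence length), so these flat positions would point at the wrong frames.
```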
> I wonder if the `reshape(n_b, n_t * n_index)` will be handled correctly, i.e. that it gets back the batch dim. This would be via the `Unflatten` module. Maybe we need to extend the logic here.
Yea, it is not correct:
```
*** root/'Unflatten' layer dict: {'class': 'split_dims', 'from': 'Cat', 'axis': 'F', 'dims': [Batch(3), 10]}
*** root/'Unflatten' SplitDimsLayer output: ['Unflatten_split_dims0'(3),F|'Unflatten_split_dims1'(10)]
```
It should not use `Batch(3)` here for `dims` but `batch_dim` instead. Also, the `10` is wrong. It needs sth similar to `Batch` which references the dynamic spatial axis, or more specifically the dim tag (for T).
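Roughly, the layer dict would need to look more like this (a sketch only, with assumed dim tag construction, not the actual converter output):

```python
from returnn.tf.util.data import batch_dim, SpatialDim

# Sketch only: the concrete dims depend on the traced graph; the point is that
# "dims" should carry dim tags (batch_dim, and a tag tied to the dynamic time
# axis) instead of the literal sizes 3 and 10 observed at tracing time.
time_dim = SpatialDim("time")
unflatten_layer = {
    "class": "split_dims", "from": "Cat", "axis": "F",
    "dims": [batch_dim, time_dim],
}
```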
Also this is already wrong:
```
*** root/'Range' layer dict: {'class': 'range', 'limit': 15, 'start': 0, 'delta': 1, 'dtype': None, 'sparse': False}
*** root/'Range' RangeLayer output: [F|'Range:range'(15)]
```
This is trickier to fix... We need a more generic `Range` module implementation. We cannot wrap `RangeLayer` in all cases.
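One possible direction (an assumption on my side, not necessarily the fix we go for): when the limit is a dynamic size such as B*T, produce the indices over the axis itself, e.g. via RETURNN `range_in_axis`, instead of wrapping `RangeLayer` with a fixed int limit:

```python
# Sketch with an assumed layer name "flat_data" holding the merged B*T axis:
range_layer = {
    "class": "range_in_axis", "from": "flat_data", "axis": "T",
}
# This yields indices 0 .. len(axis)-1 along the dynamic axis, with the proper
# dim tag attached, which a plain {"class": "range", "limit": 15} cannot provide.
```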
For reference, in RETURNN, this works now:
"flat": {"class": "flatten_batch", "from": "data"},
"length_flat": {"class": "length", "from": "flat", "axis": batch_dim},
"indices_flat": {"class": "rand_int", "shape": (batch_dim, time_dim, SpatialDim("other-spatial", 7)), "minval": 0, "maxval": "length_flat", "seed": 42},
"output": {"class": "gather", "from": "flat", "axis": batch_dim, "position": "indices_flat"},
Ok, now that #87 is merged, can you rebase, and check if it maybe works already? If not, why does it not work yet?
I think it does not yet work because we have `idx = torch.arange(n_b * n_t)`, and when multiplying `SizeValue`s, the `originating_tensor` is lost.
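As a rough sketch of the kind of bookkeeping that would avoid losing this information (not the actual `SizeValue` implementation in pytorch_to_returnn, just the idea):

```python
class SizeValue(int):
    """Rough sketch only: keep track of which dims a product came from, so that
    as_tensor() can still be derived even after originating_tensor is gone."""

    def __new__(cls, value, originating_tensor=None, merged_dims=()):
        obj = super().__new__(cls, value)
        obj.originating_tensor = originating_tensor
        obj.merged_dims = tuple(merged_dims)
        return obj

    def __mul__(self, other):
        self_dims = self.merged_dims or (self,)
        if isinstance(other, SizeValue):
            other_dims = other.merged_dims or (other,)
        else:
            other_dims = (int(other),)
        # The product has no single originating tensor, but it remembers its factors.
        return SizeValue(int(self) * int(other), merged_dims=self_dims + other_dims)
```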
Rebase anyway and let's see. I'm not sure.
Oh sorry, I accidentally pushed it to my fork before.
If `self.originating_tensor` is `None` in `as_tensor`, we could probably check if `self.merged_dims` contains some dims and, if so, compute the output based on them.
> If `self.originating_tensor` is `None` in `as_tensor`, we could probably check if `self.merged_dims` contains some dims and, if so, compute the output based on them.
Yes, this would be one solution, as a fallback.
Sth like:

```python
def as_tensor(self):
    if self.originating_tensor is None and self.merged_dims:
        # Fallback: reconstruct the value as the product of the merged dims;
        # dynamic dims (dimension is None) contribute their tensor, static dims their int.
        return numpy.prod([d.as_tensor() if d.dim_tag.dimension is None else int(d) for d in self.merged_dims])
    assert self.originating_tensor is not None
    ...
```
> > If `self.originating_tensor` is `None` in `as_tensor`, we could probably check if `self.merged_dims` contains some dims and, if so, compute the output based on them.
>
> Yes, this would be one solution, as a fallback. Sth like: [...]
Ok this is probably more accurate than what I pushed.
Now, if we have `as_tensor` of a time dim, we get a batched tensor from `LengthLayer`. I'm not sure what the intended behavior would be in this case as we want to have a scalar. Should it always be the maximum value?
> Now, if we have `as_tensor` of a time dim, we get a batched tensor from `LengthLayer`. I'm not sure what the intended behavior would be in this case as we want to have a scalar. Should it always be the maximum value?
The intended behavior is clear. This `as_tensor()` should describe the shape dim, i.e. be a scalar. So then we need some `torch.max`.
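For example (illustrative values only):

```python
import torch

# LengthLayer yields per-sequence lengths (a batched tensor); as a shape entry
# we need a single scalar, i.e. the padded time dim.
seq_lens = torch.tensor([7, 10, 4])
n_time = torch.max(seq_lens)
assert int(n_time) == 10
```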
There is a bug in our `Cat` module. The out shape is wrong. We also need to define `_get_output_shape_from_returnn` there.
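The shape logic itself is simple; a sketch of the idea only (the actual method and its signature in the repo may differ):

```python
from typing import Sequence, Tuple

def cat_output_shape(input_shapes: Sequence[Tuple[int, ...]], dim: int) -> Tuple[int, ...]:
    # Output shape of a concat: same as the inputs, except along the concat
    # axis, where the sizes add up.
    out = list(input_shapes[0])
    out[dim] = sum(shape[dim] for shape in input_shapes)
    return tuple(out)

assert cat_output_shape([(3, 10, 7), (3, 10, 5)], dim=-1) == (3, 10, 12)
```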
I did some fixes here. I think it's mostly finished, except that I also need to make another fix on the RETURNN side in `SplitDimsLayer`.
The change in `Unflatten` is causing some problems. I'm working on it.
The introduction of dim tags needs some logic similar to returnn-common where we serialize them for the net dict. This is needed for the converter. I will do that later.
Edit: For this PR, the serialization is actually not needed. Some fixes in the converter logic were enough for now. We need the serialization once we want to serialize/save the actual net dict / config.
We also need https://github.com/rwth-i6/returnn/pull/913 here. Edit: Merged.
Add a test case which demonstrates the issue #84.