rwth-i6 / pytorch-to-returnn

Make PyTorch code runnable within RETURNN

indexing merged dim #85

Closed. vieting closed this 2 years ago

vieting commented 2 years ago

Add a test case which demonstrates the issue #84.
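
Roughly, the kind of pattern involved looks like this (a minimal illustrative sketch, not the exact test code; the names are placeholders):

import torch

def model_func(y):  # y: [n_batch, n_time, n_in]
    n_b, n_t, n_in = y.shape                   # derive the dims from the input
    merged = y.reshape(n_b * n_t, n_in)        # merge batch and time into one B*T axis
    idx = torch.arange(n_b * n_t)              # indices into the merged B*T axis
    gathered = merged[idx]                     # index the merged dim
    return gathered.reshape(n_b, n_t * n_in)   # split the merged axis again, recovering the batch dim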

vieting commented 2 years ago

@albertz it would be great if you could have a look at this.

albertz commented 2 years ago

One problem seems to be the implementation of MergeDims (here):

In case the batch axis is involved, this uses the RETURNN FlattenBatchLayer.

This is wrong, or at least inconsistent with what PyTorch does. Why do we even need this separate logic? The RETURNN MergeDimsLayer should also handle the case with the batch axis, I think. If not, we should fix that. But using FlattenBatchLayer does something different.
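
For illustration, the two operations differ as soon as sequences are padded (a plain-PyTorch sketch, not the library code; shapes are made up):

import torch

x = torch.zeros(2, 5, 3)   # [B, T_max, F], with sequence lengths e.g. [5, 3]
seq_lens = [5, 3]

# MergeDimsLayer-like: a plain reshape, padded frames are kept -> [B*T_max, F] = [10, 3]
merged = x.reshape(2 * 5, 3)

# FlattenBatchLayer-like: padded frames are dropped -> [sum(seq_lens), F] = [8, 3]
flat = torch.cat([x[b, :seq_lens[b]] for b in range(2)])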

albertz commented 2 years ago

The test case is problematic. You should not use the outer n_batch, n_time inside the function model_func. You should redefine them there like:

n_batch, n_time, n_in = y.shape

This is important so that we keep the batch dim properly.

vieting commented 2 years ago

The test case is problematic. You should not use the outer n_batch, n_time inside the function model_func. You should redefine them there like n_batch, n_time, n_in = y.shape. This is important so that we keep the batch dim properly.

Yes, done.

albertz commented 2 years ago

I wonder if the reshape(n_b, n_t * n_index) will be handled correctly, i.e. whether it gets back the batch dim. This would be via the Unflatten module. Maybe we need to extend the logic here.

albertz commented 2 years ago

Why are the tests not running now?

albertz commented 2 years ago

Btw, you should rebase to resolve the conflicts.

vieting commented 2 years ago

One problem seems to be the implementation of MergeDims (here): [...] This is wrong, or at least inconsistent with what PyTorch does. Why do we even need this separate logic? The RETURNN MergeDimsLayer should also handle the case with the batch axis, I think. If not, we should fix that. But using FlattenBatchLayer does something different.

Do you mean just removing that special case from MergeDims, like what I just committed?

albertz commented 2 years ago

One problem seems to be the implementation of MergeDims (here): [...] using FlattenBatchLayer [...]

Do you mean just removing that special case from MergeDims, like what I just committed?

Yes. Not sure if that breaks something else, or why we even did that in the first place.

But this is required here. In particular, the indices (via arange(n_b * n_t)) index into the merged B*T axis. This must be merged and not flattened.
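
For illustration (hedged sketch, made-up shapes): the arange-based indexing relies on element (b, t) sitting at flat position b * n_t + t, which only holds for the merged axis:

import torch

n_b, n_t, n_f = 2, 5, 3
x = torch.arange(n_b * n_t * n_f).reshape(n_b, n_t, n_f)
merged = x.reshape(n_b * n_t, n_f)   # merged B*T axis, padded frames kept
b, t = 1, 2
assert torch.equal(merged[b * n_t + t], x[b, t])
# A flattened-batch axis (padding removed) would not satisfy this in general.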

albertz commented 2 years ago

I wonder if the reshape(n_b, n_t * n_index) will be handled correctly, i.e. whether it gets back the batch dim. This would be via the Unflatten module. Maybe we need to extend the logic here.

Yeah, it is not correct:

*** root/'Unflatten' layer dict: {'class': 'split_dims', 'from': 'Cat', 'axis': 'F', 'dims': [Batch(3), 10]}
*** root/'Unflatten' SplitDimsLayer output: ['Unflatten_split_dims0'(3),F|'Unflatten_split_dims1'(10)]

It should not use Batch(3) here for dims but batch_dim instead. Also, the 10 is wrong. It needs something similar to Batch which references the dynamic spatial axis, or more specifically the dim tag (for T).
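
A rough sketch of what it should presumably produce instead (hedged; batch_dim and time_dim are RETURNN dim tags as used in the config further below, n_index_dim is a placeholder for the static index dim, and the exact dim-tag arithmetic may differ):

"Unflatten": {"class": "split_dims", "from": "Cat", "axis": "F",
              "dims": [batch_dim, time_dim * n_index_dim]}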

albertz commented 2 years ago

Also this is already wrong:

*** root/'Range' layer dict: {'class': 'range', 'limit': 15, 'start': 0, 'delta': 1, 'dtype': None, 'sparse': False}
*** root/'Range' RangeLayer output: [F|'Range:range'(15)]

This is trickier to fix... We need a more generic Range module implementation. We cannot simply wrap RangeLayer in all cases.
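
A hedged sketch of the case distinction a more generic Range module would need (layer_of and axis_of are hypothetical helpers, SizeValue is the library's int subclass; for the dynamic case, a RETURNN layer such as "range_in_axis" is one candidate):

def make_range_layer(limit):  # hypothetical converter helper, illustrative only
    if not isinstance(limit, SizeValue):
        # static limit: wrapping RangeLayer is fine, as done now
        return {"class": "range", "limit": int(limit), "start": 0, "delta": 1}
    # dynamic limit, e.g. a SizeValue like n_b * n_t: the range must be derived
    # from the dim tag / lengths of the corresponding axis, not a static number
    return {"class": "range_in_axis", "from": layer_of(limit), "axis": axis_of(limit)}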

albertz commented 2 years ago

For reference, in RETURNN, this works now:

"flat": {"class": "flatten_batch", "from": "data"},
"length_flat": {"class": "length", "from": "flat", "axis": batch_dim},
"indices_flat": {"class": "rand_int", "shape": (batch_dim, time_dim, SpatialDim("other-spatial", 7)), "minval": 0, "maxval": "length_flat", "seed": 42},
"output": {"class": "gather", "from": "flat", "axis": batch_dim, "position": "indices_flat"},

(Via https://github.com/rwth-i6/returnn/pull/910)

albertz commented 2 years ago

Ok, now that #87 is merged, can you rebase, and check if it maybe works already? If not, why does it not work yet?

vieting commented 2 years ago

I think it does not yet work because we have idx = torch.arange(n_b * n_t) and when multiplying SizeValues, the originating_tensor is lost.
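
A hedged sketch of the idea (simplified, not the actual library code): the product of two SizeValues can no longer point to a single originating_tensor, but it could remember its factors so that as_tensor() can later be reconstructed from them:

class SizeValue(int):  # simplified stand-in for the library class, illustrative only
    def __new__(cls, value, originating_tensor=None, merged_dims=()):
        obj = super().__new__(cls, value)
        obj.originating_tensor = originating_tensor
        obj.merged_dims = tuple(merged_dims)
        return obj

    def __mul__(self, other):
        # remember the merged factors instead of dropping all origin information
        merged = (self, other) if isinstance(other, SizeValue) else ()
        return SizeValue(int(self) * int(other), merged_dims=merged)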

albertz commented 2 years ago

Rebase anyway and let's see. I'm not sure.

vieting commented 2 years ago

Oh sorry, I accidentally pushed it to my fork before.

vieting commented 2 years ago

If self.originating_tensor is None in as_tensor, we could probably check if self.merged_dims contains some dims and if so, compute the output based on them.

albertz commented 2 years ago

If self.originating_tensor is None in as_tensor, we could probably check if self.merged_dims contains some dims and if so, compute the output based on them.

Yes, this would be one solution, as a fallback.

Something like:

def as_tensor(self):
  # Fallback: no originating tensor, but the merged factors are known,
  # so the size is the product of their sizes (assumes numpy is imported at module level).
  if self.originating_tensor is None and self.merged_dims:
    return numpy.prod([d.as_tensor() if d.dim_tag.dimension is None else int(d) for d in self.merged_dims])
  assert self.originating_tensor is not None
  ...
vieting commented 2 years ago

If self.originating_tensor is None in as_tensor, we could probably check if self.merged_dims contains some dims and if so, compute the output based on them.

Yes, this would be one solution, as a fallback.

Something like:

def as_tensor(self):
  # Fallback: no originating tensor, but the merged factors are known,
  # so the size is the product of their sizes (assumes numpy is imported at module level).
  if self.originating_tensor is None and self.merged_dims:
    return numpy.prod([d.as_tensor() if d.dim_tag.dimension is None else int(d) for d in self.merged_dims])
  assert self.originating_tensor is not None
  ...

OK, this is probably more accurate than what I pushed.

vieting commented 2 years ago

Now, if we have as_tensor of a time dim, we get a batched tensor from LengthLayer. I'm not sure what the intended behavior would be in this case as we want to have a scalar. Should it always be the maximum value?

albertz commented 2 years ago

Now, if we have as_tensor of a time dim, we get a batched tensor from LengthLayer. I'm not sure what the intended behavior would be in this case as we want to have a scalar. Should it always be the maximum value?

The intended behavior is clear. This as_tensor() should describe the shape dim, i.e. be a scalar. So then we need some torch.max.
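
For illustration, a hedged sketch (dim_as_scalar is a hypothetical helper, not the actual fix):

import torch

def dim_as_scalar(seq_lens):  # seq_lens: per-sequence lengths from LengthLayer, shape [B]
    # as_tensor() of a dynamic dim should describe the (padded) dim size, i.e. a scalar
    return torch.max(seq_lens)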

albertz commented 2 years ago

There is a bug in our Cat module. The output shape is wrong. We also need to define _get_output_shape_from_returnn there.

albertz commented 2 years ago

I did some fixes here. I think it's mostly finished, except that I also need to make another fix on RETURNN side in SplitDimsLayer.

albertz commented 2 years ago

The change in Unflatten is causing some problems. I'm working on it.

albertz commented 2 years ago

The introduction of dim tags needs some logic similar to returnn-common where we serialize them for the net dict. This is needed for the converter. I will do that later.

Edit: For this PR, the serialization is actually not needed. Some fixes in the converter logic were enough for now. We need the serialization once we want to serialize/save the actual net dict / config.

albertz commented 2 years ago

We also need https://github.com/rwth-i6/returnn/pull/913 here. Edit: Merged.