rwth-i6 / pytorch-to-returnn

Make PyTorch code runnable within RETURNN

issue when indexing merged dim #84

Closed: vieting closed this 2 years ago.

vieting commented 2 years ago

As shown by the test case in #85, there is currently a bug when indexing a dim that merges (flattens) the batch and time dims.

Code:

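```python
import torch

# Sizes as in the log below: batch 3, time 5, 2 indices, feature dim 7.
n_batch, n_time, n_index, n_in = 3, 5, 2, 7
inputs = torch.randn(n_batch, n_time, n_in)  # [B, T, F]

y = inputs.view(-1, n_in)  # merge batch and time: [B*T, F]
idx = torch.arange(n_batch * n_time)
idx = torch.cat([idx] * n_index).reshape(n_batch, n_time * n_index)
x = y[idx.view(-1)]  # index into the merged B*T dim
x = x.view(n_batch, n_time, n_index, n_in)
x = x.permute(2, 0, 1, 3)  # [n_index, B, T, F]
out = torch.cat([inputs.unsqueeze(0), x])  # [1 + n_index, B, T, F]
```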
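For reference, the plain-PyTorch semantics of this snippet can be checked directly. This is a minimal sketch, assuming every sequence has the full length `n_time` (no padding) and using the sizes from the log below (B=3, T=5, n_index=2, F=7). `check_merged_dim_indexing` is a hypothetical helper, not part of pytorch-to-returnn or the test in #85. It checks that row `i` of the merged dim is `inputs[i // n_time, i % n_time]`, and that `x[k, b, t]` picks row `(b * n_time * n_index + t * n_index + k) % (n_batch * n_time)` of it.

```python
import torch


def check_merged_dim_indexing(n_batch=3, n_time=5, n_index=2, n_in=7):
    """Hypothetical helper: checks the plain-PyTorch semantics of the snippet."""
    inputs = torch.randn(n_batch, n_time, n_in)  # [B, T, F], no padding assumed
    y = inputs.view(-1, n_in)  # merged B*T dim, batch-major order

    # Row i of the merged dim is batch entry i // n_time at time step i % n_time.
    for i in range(n_batch * n_time):
        assert torch.equal(y[i], inputs[i // n_time, i % n_time])

    idx = torch.arange(n_batch * n_time)
    idx = torch.cat([idx] * n_index).reshape(n_batch, n_time * n_index)
    x = y[idx.view(-1)].view(n_batch, n_time, n_index, n_in).permute(2, 0, 1, 3)
    out = torch.cat([inputs.unsqueeze(0), x])

    # Each x[k, b, t] is one row of the merged dim; the indices wrap modulo B*T.
    for k in range(n_index):
        for b in range(n_batch):
            for t in range(n_time):
                row = (b * n_time * n_index + t * n_index + k) % (n_batch * n_time)
                assert torch.equal(x[k, b, t], y[row])
    return out.shape


print(check_merged_dim_indexing())  # torch.Size([3, 3, 5, 7])
```

With the default sizes the result is the original input plus `n_index` gathered tensors, stacked on a new leading axis.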

Log output (via):

```
2022-01-18T10:48:42.2941933Z >>> Running with wrapped Torch import, wrapping replacement for PyTorch...
2022-01-18T10:48:42.2942434Z RETURNN input: Data{'data', [B,T|'time:var:extern_data:data'[B],F|F'feature:data'(7)]}
2022-01-18T10:48:42.2943002Z *** root/'Flatten' layer dict: {'class': 'flatten_batch', 'from': 'data', 'axis': 'T', 'batch_major': True}
2022-01-18T10:48:42.2943565Z *** root/'Flatten' FlattenBatchLayer output: [B&Packed{'time:var:extern_data:data'},F|F'feature:data'(7)]
2022-01-18T10:48:42.2944131Z *** root/'Range' layer dict: {'class': 'range', 'limit': 15, 'start': 0, 'delta': 1, 'dtype': None, 'sparse': False}
2022-01-18T10:48:42.2944574Z *** root/'Range' RangeLayer output: [F|'Range:range'(15)]
2022-01-18T10:48:42.2945006Z *** root/'Cat' layer dict: {'class': 'concat', 'from': [('Range', 'F'), ('Range', 'F')]}
2022-01-18T10:48:42.2945425Z *** root/'Cat' ConcatLayer output: [F|'2*Range:range'(30)]
2022-01-18T10:48:42.2945889Z *** root/'Unflatten' layer dict: {'class': 'split_dims', 'from': 'Cat', 'axis': 'F', 'dims': [3, 10]}
2022-01-18T10:48:42.2946418Z *** root/'Unflatten' SplitDimsLayer output: ['Unflatten_split_dims0'(3),F|'Unflatten_split_dims1'(10)]
2022-01-18T10:48:42.2946976Z *** root/'Flatten_1' layer dict: {'class': 'merge_dims', 'from': 'Unflatten', 'axes': ['static:0', 'F'], 'keep_order': True}
2022-01-18T10:48:42.2947506Z *** root/'Flatten_1' MergeDimsLayer output: [F|'Unflatten_split_dims0*Unflatten_split_dims1'(30)]
2022-01-18T10:48:42.2948064Z *** root/'GatherTensor' layer dict: {'class': 'gather', 'from': 'Flatten', 'axis': 'B', 'position': 'Flatten_1'}
2022-01-18T10:48:42.2948646Z *** root/'GatherTensor' GatherLayer output: ['Unflatten_split_dims0*Unflatten_split_dims1'(30),F|F'feature:data'(7)]
2022-01-18T10:48:42.2949218Z *** root/'Unflatten_1' layer dict: {'class': 'split_dims', 'from': 'GatherTensor', 'axis': 'static:0', 'dims': [3, 5, 2]}
2022-01-18T10:48:42.2949865Z *** root/'Unflatten_1' SplitDimsLayer output: ['Unflatten_1_split_dims0'(3),'Unflatten_1_split_dims1'(5),'Unflatten_1_split_dims2'(2),F|F'feature:data'(7)]
2022-01-18T10:48:42.2950388Z *** root/'Transpose' layer dict: {'class': 'copy', 'from': 'Unflatten_1'}
2022-01-18T10:48:42.2950988Z *** root/'Transpose' CopyLayer output: ['Unflatten_1_split_dims0'(3),'Unflatten_1_split_dims1'(5),'Unflatten_1_split_dims2'(2),F|F'feature:data'(7)]
2022-01-18T10:48:42.2951549Z *** root/'Unflatten_2' layer dict: {'class': 'split_dims', 'from': 'data', 'axis': 'B', 'dims': [1, -1]}
2022-01-18T10:48:42.2952130Z *** root/'Unflatten_2' SplitDimsLayer output: ['Unflatten_2_split_dims0'(1),B,T|'time:var:extern_data:data'[B],F|F'feature:data'(7)]
2022-01-18T10:48:42.2952705Z *** root/'Cat_1' layer dict: {'class': 'concat', 'from': [('Unflatten_2', 'static:0'), ('Transpose', 'static:2')]}
2022-01-18T10:48:42.2953182Z Exception creating layer root/'Cat_1' of class ConcatLayer with opts:
2022-01-18T10:48:42.2953492Z {'_name': 'Cat_1',
2022-01-18T10:48:42.2953805Z '_network': ,
2022-01-18T10:48:42.2954086Z 'name': 'Cat_1',
2022-01-18T10:48:42.2954398Z 'network': ,
2022-01-18T10:48:42.2955025Z 'sources': [(,
2022-01-18T10:48:42.2955441Z 'static:0'),
2022-01-18T10:48:42.2955989Z (,
2022-01-18T10:48:42.2956407Z 'static:2')]}
```

This then fails with an exception in:

```
  File "/home/runner/.local/lib/python3.8/site-packages/returnn/tf/util/data.py", line 330, in Dim.get_for_batch_ctx
    line: dyn_size_ext = base_can_use_in_ctx.dyn_size_ext.copy_extend_batch(batch)
    locals:
      dyn_size_ext = <local> None
      base_can_use_in_ctx = <local> Dim{'time:var:extern_data:data'[?]}
      base_can_use_in_ctx.dyn_size_ext = <local> None
      base_can_use_in_ctx.dyn_size_ext.copy_extend_batch = <local> !AttributeError: 'NoneType' object has no attribute 'copy_extend_batch'
      batch = <local> BatchInfo{B, Packed{'time:var:extern_data:data'}}
  File "/home/runner/.local/lib/python3.8/site-packages/returnn/tf/util/data.py", line 3441, in Data.copy_extend_batch
    line: new_dims = ensure_list_of_type(batch.virtual_dims, BatchInfo.FixedDim)
    locals:
      new_dims = <not found>
      ensure_list_of_type = <local> <function ensure_list_of_type at 0x7fee81b4f820>
      batch = <local> BatchInfo{B, Packed{'time:var:extern_data:data'}}
      batch.virtual_dims = <local> [GlobalBatchDim{B}, PackedDim{Packed{'time:var:extern_data:data'}}]
```

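The relevant layers: `Cat_1` concatenates `Unflatten_2`, which still has the `B` and `T` dims, with `Transpose`, where batch and time only appear as static dims of size 3 and 5:

```
'Unflatten_2': SplitDimsLayer output: ['Unflatten_2_split_dims0'(1),B,T|'time:var:extern_data:data'[B],F|F'feature:data'(7)]
'Transpose' CopyLayer output: ['Unflatten_1_split_dims0'(3),'Unflatten_1_split_dims1'(5),'Unflatten_1_split_dims2'(2),F|F'feature:data'(7)]
'Cat_1' layer dict: {'class': 'concat', 'from': [('Unflatten_2', 'static:0'), ('Transpose', 'static:2')]}
```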
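To relate the RETURNN layer names in the log to the PyTorch ops of the test case, here is a shape trace in plain PyTorch. The correspondence between ops and layer names is inferred from the layer dicts in the log (e.g. `'dims': [3, 10]` for `Unflatten` and `'dims': [3, 5, 2]` for `Unflatten_1`), not taken from the pytorch-to-returnn code; `trace_shapes` is a hypothetical helper. In PyTorch, both inputs of the final `cat` carry batch and time as axes 1 and 2.

```python
import torch


def trace_shapes(n_batch=3, n_time=5, n_index=2, n_in=7):
    """Hypothetical helper: shape of each intermediate, keyed by the layer name in the log."""
    inputs = torch.randn(n_batch, n_time, n_in)  # RETURNN input 'data': [B, T, F]
    shapes = {}

    y = inputs.view(-1, n_in)
    shapes["Flatten"] = tuple(y.shape)  # flatten_batch over B and T
    idx = torch.arange(n_batch * n_time)
    shapes["Range"] = tuple(idx.shape)  # range with limit 15
    idx = torch.cat([idx] * n_index)
    shapes["Cat"] = tuple(idx.shape)  # concat of Range with itself
    idx = idx.reshape(n_batch, n_time * n_index)
    shapes["Unflatten"] = tuple(idx.shape)  # split_dims with dims [3, 10]
    idx = idx.view(-1)
    shapes["Flatten_1"] = tuple(idx.shape)  # merge_dims back to 30
    x = y[idx]
    shapes["GatherTensor"] = tuple(x.shape)  # gather on the merged B&T axis
    x = x.view(n_batch, n_time, n_index, n_in)
    shapes["Unflatten_1"] = tuple(x.shape)  # split_dims with dims [3, 5, 2]
    x = x.permute(2, 0, 1, 3)
    # In the log, 'Transpose' is a plain copy that keeps the axis order; the
    # reordering appears instead as Cat_1 concatenating along 'static:2'.
    shapes["Transpose"] = tuple(x.shape)
    u = inputs.unsqueeze(0)
    shapes["Unflatten_2"] = tuple(u.shape)  # split_dims on B with dims [1, -1]
    out = torch.cat([u, x])
    shapes["Cat_1"] = tuple(out.shape)  # the layer that fails in RETURNN
    return shapes


for name, shape in trace_shapes().items():
    print(f"{name:>12}: {shape}")
```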

albertz commented 2 years ago

This should be fixed via #85.