rwth-i6 / pytorch-to-returnn

Make PyTorch code runnable within RETURNN
Output dim tag of torch.full #114

Closed vieting closed 2 years ago

vieting commented 2 years ago

#113 shows that the dim tag of the output shape of torch.full is not set correctly. The following test reproduces the issue:

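# Imports assumed from the repository's test setup (not shown in the original report).
import typing
import numpy
import torch
from pytorch_to_returnn.converter import verify_torch_and_convert_to_returnn


def test_full_out_dim_tag():
  n_batch, n_feat = 3, 5

  def model_func(wrapped_import, inputs: torch.Tensor):
    if typing.TYPE_CHECKING or not wrapped_import:
      import torch
    else:
      torch = wrapped_import("torch")
    # Constant tensor whose size is taken from the batch axis of the input.
    full = torch.full((inputs.shape[0],), 42)
    # The range is sized from the output shape of torch.full.
    arange = torch.arange(full.shape[0])
    return arange + full

  rnd = numpy.random.RandomState(42)
  x = rnd.normal(0., 1., (n_batch, n_feat)).astype("float32")
  verify_torch_and_convert_to_returnn(model_func, inputs=x)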

which fails in pytorch_to_returnn/torch/nn/modules/operator.py, line 594, in _unify_tensor_axes_returnn_meta with

    line: assert all(dim.dimension == d.dimension for d in dims_for_axis if d is not None and d.dimension != 1), (
            f"invalid input {x} axis {i} dim {dim}")
    locals:
      all = <builtin> <built-in function all>
      dim = <local> Dim{'Range:range'(3)}
      dim.dimension = <local> 3
      d = <not found>
      d.dimension = <not found>
      dims_for_axis = <local> [Dim{'Range:range'(3)}, Dim{B}]
AssertionError: invalid input <Tensor name:? tensor:('static_dim'(3)(3),) returnn_data:'Cast_output' [F|'Range:range'(3)] axes id> axis 0 dim Dim{'Range:range'(3)}

(I'll add a full stack trace and log later).

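The locals below show that the dim tag in the output shape of torch.full is wrong ('static_dim'(3) instead of the batch dim), while returnn_data still carries the correct one ([B]). torch.arange then builds its range from that wrong dim tag, so on the RETURNN side the range gets its own dim 'Range:range'(3) rather than the batch dim. Adding the two tensors then fails because their axes, 'Range:range'(3) and B, cannot be unified.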

      arange = <local> <Tensor name:? tensor:('static_dim'(3)(3),) returnn_data:'Range_output' [F|'Range:range'(3)] axes id>
      full = <local> <Tensor name:? tensor:('static_dim'(3)(3),) returnn_data:'FullStatic_const' [B] axes id>
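
To make the failing condition easier to see, here is a minimal standalone sketch of the check from the assertion above. It is not the converter's code: SketchDim and dims_agree are hypothetical names, and it assumes that the batch dim reports dimension None (dynamic size), which is not shown in the locals.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class SketchDim:
  """Hypothetical stand-in for a RETURNN dim tag; not the real class."""
  description: str
  dimension: Optional[int]  # None models a dynamic size (assumed for the batch dim)


def dims_agree(dim: SketchDim, dims_for_axis: List[Optional[SketchDim]]) -> bool:
  """Same condition as the failing assert: skip missing dims and broadcast dims of size 1."""
  return all(
    dim.dimension == d.dimension
    for d in dims_for_axis
    if d is not None and d.dimension != 1)


range_dim = SketchDim("Range:range", 3)
batch_dim = SketchDim("B", None)

# Mirrors the locals from the traceback: the static size 3 does not match the batch dim.
mismatch = dims_agree(range_dim, [range_dim, batch_dim])
# If the range had been built on the batch dim as well, both axes would agree.
consistent = dims_agree(batch_dim, [batch_dim, batch_dim])
print(mismatch, consistent)
```

Under these assumptions the first call fails the check and the second passes, which matches the expectation that the output shape of torch.full should carry the batch dim tag.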