As seen in #113, the dim tag of the output shape of `torch.full` is not set correctly. An example to reproduce the issue looks like this:
```python
import typing
import numpy
import torch
from pytorch_to_returnn.converter import verify_torch_and_convert_to_returnn


def test_full_out_dim_tag():
  n_batch, n_feat = 3, 5

  def model_func(wrapped_import, inputs: torch.Tensor):
    if typing.TYPE_CHECKING or not wrapped_import:
      import torch
    else:
      torch = wrapped_import("torch")
    full = torch.full((inputs.shape[0],), 42)
    arange = torch.arange(full.shape[0])
    return arange + full

  rnd = numpy.random.RandomState(42)
  x = rnd.normal(0., 1., (n_batch, n_feat)).astype("float32")
  verify_torch_and_convert_to_returnn(model_func, inputs=x)
```
which fails in `pytorch_to_returnn/torch/nn/modules/operator.py`, line 594, in `_unify_tensor_axes_returnn_meta` with:

```
line: assert all(dim.dimension == d.dimension for d in dims_for_axis if d is not None and d.dimension != 1), (
    f"invalid input {x} axis {i} dim {dim}")
locals:
  all = <builtin> <built-in function all>
  dim = <local> Dim{'Range:range'(3)}
  dim.dimension = <local> 3
  d = <not found>
  d.dimension = <not found>
  dims_for_axis = <local> [Dim{'Range:range'(3)}, Dim{B}]
AssertionError: invalid input <Tensor name:? tensor:('static_dim'(3)(3),) returnn_data:'Cast_output' [F|'Range:range'(3)] axes id> axis 0 dim Dim{'Range:range'(3)}
```
(I'll add a full stack trace and log later).
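
To make the failing check more concrete, here is a hypothetical, self-contained re-enactment with a stand-in `Dim` class (the real tags come from RETURNN; that the batch tag `Dim{B}` is dynamic with `dimension=None` is an assumption based on RETURNN's dynamic batch dim):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Dim:
  """Stand-in for RETURNN's Dim; only .dimension matters for this check."""
  description: str
  dimension: Optional[int]  # assumed None for dynamic dims such as the batch dim


# The two tags that end up on axis 0 in the failing case above:
range_dim = Dim("Range:range", 3)  # static tag from the arange output
batch_dim = Dim("B", None)         # dynamic batch tag the axis should carry

dims_for_axis = [range_dim, batch_dim]
dim = dims_for_axis[0]

try:
  # Same shape of check as in _unify_tensor_axes_returnn_meta: all tags for
  # one axis must agree on .dimension. Here 3 != None, so this raises.
  assert all(
    dim.dimension == d.dimension
    for d in dims_for_axis if d is not None and d.dimension != 1)
except AssertionError:
  print("cannot unify axis: dimension 3 (Range:range) vs None (batch dim B)")
```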
We can see that the dim tag in the output shape of `torch.full` is wrong, but in `returnn_data` it is still correct. When creating a range based on that shape, the wrong dim tag is also used on the RETURNN side, which then leads to the error when adding the two tensors.
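
For comparison, the same sequence of ops is fine in plain PyTorch, which suggests the problem is purely in the dim tag bookkeeping of the conversion (a minimal sketch, not part of the test suite):

```python
import torch

inputs = torch.randn(3, 5)                 # (n_batch, n_feat)
full = torch.full((inputs.shape[0],), 42)  # shape (3,)
arange = torch.arange(full.shape[0])       # shape (3,)
out = arange + full                        # works: both are plain (3,) tensors
print(out.shape)                           # torch.Size([3])
```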