rwth-i6 / pytorch-to-returnn

Make PyTorch code runnable within RETURNN

new zeros with dynamic axes #69

Closed vieting closed 2 years ago

vieting commented 2 years ago

tensor.new_zeros() currently only works for static dims. I added a test case to demonstrate this. I think the comparison of dim tags should be fixed as in the commit, but the axes are still not matched correctly, and RETURNN inserts an extra axis.
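To illustrate why static and dynamic dims need different comparison rules, here is a toy model in plain Python. Dim and is_same_dim are hypothetical names, not the RETURNN or pytorch-to-returnn API, and this is not necessarily what the commit does. The assumption is that a static dim (e.g. a feature axis of size 11) can be matched by its size. A dynamic dim (e.g. a time axis whose length differs per sequence) can only be matched by tag identity, here extended to dims explicitly derived from it.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(eq=False)  # eq=False: instances compare by identity, like distinct dim tags
class Dim:
    """Hypothetical toy dim tag; not RETURNN's actual dim tag class."""

    description: str
    static_size: Optional[int] = None  # None means dynamic (per-sequence lengths)
    same_as: Optional["Dim"] = None  # set when this dim was derived from another one

    def root(self) -> "Dim":
        # Follow the chain of derived dims back to the original one.
        dim = self
        while dim.same_as is not None:
            dim = dim.same_as
        return dim


def is_same_dim(a: Dim, b: Dim) -> bool:
    """Static dims match by size; dynamic dims only match if they share a root tag."""
    if a.static_size is not None and b.static_size is not None:
        return a.static_size == b.static_size
    # A dynamic dim's max length says nothing about the per-sequence lengths,
    # so comparing sizes would be wrong; compare tag identity instead.
    return a.root() is b.root()


time = Dim("time")
time_of_zeros = Dim("time of new_zeros() output", same_as=time)
other_time = Dim("unrelated time")
print(is_same_dim(time, time_of_zeros), is_same_dim(time, other_time))
```

In this sketch, a tensor created via new_zeros() would have to carry a dim that refers back to the source time dim (same_as here); a fresh dynamic tag with the same length would not compare equal.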

Small extract of the error:

AssertionError: 
Not equal to tolerance rtol=0, atol=0.0005

(shapes (3, 5, 11), (5, 3, 5, 11) mismatch)

albertz commented 2 years ago

The fix in _unify_tensor_axes_returnn_meta looks good, but I guess there is still some other problem.

vieting commented 2 years ago

I think that when adding inputs and x, the reinterpreted time dim is not matched to the original time dim, which is why RETURNN creates an additional axis.
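A minimal sketch of this effect, using a toy axis representation rather than RETURNN's real data structures. unify_axes and the aliases mapping are hypothetical, and the sizes assume batch 3, time 5 and feature 11. Axes of an elementwise op are matched by tag. An unmatched tag is treated as a separate axis, so the result gains one more axis of size 5, mirroring the shapes in the error above. The prepend order is chosen only to mirror that error; RETURNN's actual axis ordering may differ.

```python
from typing import Dict, List, Optional, Tuple

# Toy axis: (tag, size), where the tag string stands in for a RETURNN dim tag.
Axis = Tuple[str, int]


def unify_axes(a: List[Axis], b: List[Axis],
               aliases: Optional[Dict[str, str]] = None) -> List[Axis]:
    """Output axes of an elementwise op on a and b, matching axes by tag.

    Axes of b whose tag (after applying aliases) does not occur in a are
    treated as new broadcast axes and prepended.
    """
    aliases = aliases or {}
    a_tags = {aliases.get(tag, tag) for tag, _ in a}
    extra = [(tag, size) for tag, size in b if aliases.get(tag, tag) not in a_tags]
    return extra + list(a)


def shape(axes: List[Axis]) -> Tuple[int, ...]:
    return tuple(size for _, size in axes)


inputs = [("batch", 3), ("time", 5), ("feature", 11)]
x = [("batch", 3), ("time_reinterpreted", 5), ("feature", 11)]

# Reinterpreted time dim not matched: an extra axis appears.
print(shape(unify_axes(inputs, x)))  # (5, 3, 5, 11)
# Reinterpreted time dim matched to the original one: shapes agree.
print(shape(unify_axes(inputs, x, aliases={"time_reinterpreted": "time"})))  # (3, 5, 11)
```

The point of the toy is that equal sizes alone do not unify the axes; only a matching tag (here via the hypothetical aliases mapping) does.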

albertz commented 2 years ago

Can you rebase this and check whether it works now? I think it should already work.

albertz commented 2 years ago

Continued work in #88.