rwth-i6 / pytorch-to-returnn

Make PyTorch code runnable within RETURNN

Add cross entropy #113

Closed. vieting closed this 2 years ago

vieting commented 2 years ago

I added torch.nn.functional.cross_entropy and a few small changes that were necessary for it. It is used in the contrastive loss of wav2vec. I also added an analogous usage in test_contrastive_loss, which is not fully working yet.
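For reference, here is a minimal pure-Python sketch of what cross entropy computes for integer class targets with mean reduction, which is the default behavior of torch.nn.functional.cross_entropy. It is not the implementation from this PR. The function name and the list-based inputs are assumptions made only for illustration.

```python
import math


def cross_entropy(logits, target):
    """Mean cross entropy over a batch (illustrative sketch, not the PR code).

    logits: list of rows, one per batch entry, each a list of class scores.
    target: list of integer class indices, one per batch entry.
    """
    losses = []
    for row, cls in zip(logits, target):
        # Numerically stable log-sum-exp: subtract the row maximum first.
        row_max = max(row)
        log_sum_exp = row_max + math.log(sum(math.exp(x - row_max) for x in row))
        # Negative log-softmax evaluated at the target class.
        losses.append(log_sum_exp - row[cls])
    # Mean reduction over the batch.
    return sum(losses) / len(losses)
```

With uniform scores over K classes, every entry contributes log(K), so the mean is log(K) regardless of the targets.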

vieting commented 2 years ago

The problem is in

   File "/home/runner/work/pytorch-to-returnn/pytorch-to-returnn/pytorch_to_returnn/torch/nn/functional.py", line 842, in cross_entropy
    line: idcs = target + torch.arange(target.shape[0]) * target.shape[0]
    locals:
      idcs = <not found>
      target = <local> <Tensor name:? tensor:('static_dim'(15)(15),) returnn_data:'FullStatic_const' [B] axes id>
      torch = <global> <module 'pytorch_to_returnn.torch' from '/home/runner/work/pytorch-to-returnn/pytorch-to-returnn/pytorch_to_returnn/torch/__init__.py'>
      torch.arange = <global> <function arange at 0x7f0e105b0af0>
      target.shape = <local> ('static_dim'(15)(15),), len = 1

For target, the dim tag in returnn_data does not match the dim tag in the SizeValue in shape. This is because in FullStatic.make_output_tensor_from_returnn, the output tensor is created via from_numpy and the size contains only ints, so it is not recognized as the batch dim. I'm not sure whether this is the core of the problem, though.
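The line in the traceback appears to build flat indices so that one entry per batch row can be picked out of a flattened 2D tensor. A minimal pure-Python sketch of that pattern follows. It uses the number of columns as the row stride, which is the general form. The traceback line multiplies by target.shape[0] instead, which gives the same result only when the number of columns equals the batch size; that this holds in the test is an assumption here. The function name flat_gather is hypothetical.

```python
def flat_gather(matrix, target):
    """Pick matrix[i][target[i]] for every row i via flat indices (sketch)."""
    num_cols = len(matrix[0])
    # Flatten row by row, so entry (i, j) lands at position i * num_cols + j.
    flat = [x for row in matrix for x in row]
    # Column index plus row offset, analogous to target + arange(B) * stride.
    idcs = [t + i * num_cols for i, t in enumerate(target)]
    return [flat[j] for j in idcs]
```

For example, flat_gather([[1, 2, 3], [4, 5, 6]], [2, 0]) selects the third entry of the first row and the first entry of the second row.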

albertz commented 2 years ago

Yea, this is wrong:

target = <local> <Tensor name:? tensor:('static_dim'(15)(15),) returnn_data:'FullStatic_const' [B] axes id>

It should not be a static_dim here.

I guess this is exactly what you reported in #114.

albertz commented 2 years ago

So, this is ready now?

vieting commented 2 years ago

Yes, I rebased and now this should be ready.