rwth-i6 / pytorch-to-returnn

Make PyTorch code runnable within RETURNN

allow custom modules with custom RETURNN layer #129

Closed. vieting closed this 1 year ago.

vieting commented 1 year ago

I'd like to create a module which has a torch forward function but is nevertheless mapped to a custom RETURNN layer call. The basic idea is something like:

import torch

class MyModule(torch.nn.Module):
  # Tell the converter not to treat this as an original torch module,
  # so that create_returnn_layer_dict below is used for the mapping.
  is_original_torch_module = False

  def forward(self, x):
    y = foo(x)  # foo is a placeholder for the actual torch computation
    return y

  def create_returnn_layer_dict(self, input):
    # "<some_layer>" is a placeholder for the RETURNN layer class name.
    return {"class": "<some_layer>", "from": self._get_input_layer_name(input)}

Right now, such a custom create_returnn_layer_dict is ignored and the code inside forward is wrapped instead. This PR changes that behavior so that the custom create_returnn_layer_dict is used.
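The intended dispatch can be sketched roughly as follows. This is a simplified illustration, not the actual pytorch-to-returnn internals: a plain base class stands in for torch.nn.Module, and make_layer_dict is a hypothetical helper that passes the input layer name directly instead of going through _get_input_layer_name.

```python
class Module:
    # The converter treats modules with this flag set to False as
    # "custom": their create_returnn_layer_dict is used directly
    # instead of wrapping the code inside forward().
    is_original_torch_module = True


def make_layer_dict(module, input_layer_name):
    """Return the RETURNN layer dict for module (hypothetical helper)."""
    custom = getattr(module, "create_returnn_layer_dict", None)
    if custom is not None and not module.is_original_torch_module:
        # Behavior after this PR: prefer the user-provided mapping.
        return custom(input_layer_name)
    # Previous (and fallback) behavior: wrap whatever forward() does,
    # here represented by a generic subnetwork layer.
    return {"class": "subnetwork", "from": input_layer_name}


class MyModule(Module):
    is_original_torch_module = False

    def create_returnn_layer_dict(self, input_layer_name):
        return {"class": "<some_layer>", "from": input_layer_name}
```

With this dispatch, make_layer_dict(MyModule(), "data") yields the custom layer dict, while a module without the flag falls back to the wrapped path.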

vieting commented 1 year ago

I added test_dct as an example test case because that's what I was interested in. We can also simplify or reformat the code in dct() if you don't like it as is.
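For context, a DCT can be written as a single matrix multiplication, which is the kind of operation that maps cleanly onto a RETURNN layer. A minimal numpy sketch of the orthonormal DCT-II, assuming that is roughly what dct() computes (this is an illustration of the math, not the code from the PR):

```python
import numpy as np

def dct_ii(x):
    """Orthonormal DCT-II over the last axis of x, as one matmul.

    basis[k, n] = cos(pi / N * (n + 0.5) * k), scaled so the basis
    matrix is orthogonal.
    """
    n = x.shape[-1]
    grid = np.arange(n)
    basis = np.cos(np.pi / n * (grid[None, :] + 0.5) * grid[:, None])
    basis *= np.sqrt(2.0 / n)
    basis[0] *= np.sqrt(0.5)  # first row rescaled for orthonormality
    return x @ basis.T
```

Since the transform is just x @ basis.T, the basis matrix can be precomputed as a constant and the whole dct() reduces to one linear/matmul-style layer on the RETURNN side.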

albertz commented 1 year ago

There are three failing tests. Did you check them?

vieting commented 1 year ago

> There are three failing tests. Did you check them?

Yes, they are the same as in #125.