rwth-i6 / pytorch-to-returnn

Make PyTorch code runnable within RETURNN

support reassignment of module class attributes #22

Closed. vieting closed this PR 2 years ago.

vieting commented 3 years ago

Currently, reassigning module class attributes, as in the test case given here, is not supported.
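To make it concrete, this is the pattern in question (a minimal excerpt from the test case further down in this thread):

self.model = torch.nn.Linear(n_in, n_out)
# The same attribute is assigned again; the wrapped module tracking
# must replace the registered Linear instead of keeping it.
self.model = torch.nn.Sequential(self.model, torch.nn.GELU())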

Just putting this here as a draft PR in order not to lose track of it.

albertz commented 3 years ago

Yeah, I know. I was just too lazy to implement this. Is this really used anywhere, though? (I'm just asking about the priority of this.)

vieting commented 3 years ago

I see. It is used in the Wav2Vec2Model. For now, I just rewrote it so that it works.
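For reference, the rewrite to avoid the reassignment can look like this (a sketch of the workaround, not the actual Wav2Vec2Model code):

class MyModel(torch.nn.Module):
  def __init__(self):
    super(MyModel, self).__init__()
    # Build the stack in a single assignment, so no module attribute
    # is reassigned after registration.
    self.model = torch.nn.Sequential(
      torch.nn.Linear(n_in, n_out),
      torch.nn.GELU())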

vieting commented 2 years ago

This does not really seem to be needed. Should we close the PR, @albertz? I could create an issue if we want to keep track of it.

albertz commented 2 years ago

Ok, let's close this for now and reopen it later when we need it.

vieting commented 2 years ago

Alright. Just putting the test case here in case we want to come back to it later:

# Imports as in the repo's test suite; the converter import path is assumed:
import typing
import numpy
import torch
from pytorch_to_returnn.converter import verify_torch_and_convert_to_returnn


def test_naming_reassign_attribute():
  n_batch, n_time = 3, 7
  n_in, n_out = 11, 13

  def model_func(wrapped_import, inputs: torch.Tensor):
    # Under the converter, `torch` is replaced by the wrapped module;
    # under plain PyTorch (or type checking), the real `torch` is used.
    if typing.TYPE_CHECKING or not wrapped_import:
      import torch
    else:
      torch = wrapped_import("torch")

    class MyModel(torch.nn.Module):
      def __init__(self):
        super(MyModel, self).__init__()
        # The attribute is assigned twice; this reassignment is what
        # the converter needs to support.
        self.model = torch.nn.Linear(n_in, n_out)
        self.model = torch.nn.Sequential(self.model, torch.nn.GELU())

      def forward(self, x):
        return self.model(x)

    model = MyModel()
    out = model(inputs.transpose(1, 2))  # (batch, time, feature)
    return out

  rnd = numpy.random.RandomState(42)
  x = rnd.normal(0., 1., (n_batch, n_in, n_time)).astype("float32")
  verify_torch_and_convert_to_returnn(model_func, inputs=x)

albertz commented 2 years ago

I think the PR branch can always be recovered, though. The PR itself is technically also a branch and can be checked out directly.
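For example, with GitHub's standard pull request refs (using this PR's number, #22; the local branch name pr-22 is an arbitrary choice):

git fetch origin pull/22/head:pr-22
git checkout pr-22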