rwth-i6 / returnn

The RWTH extensible training framework for universal recurrent neural networks
http://returnn.readthedocs.io/

Network Re-init after every epoch for no apparent reason #1109

Closed JackTemaki closed 2 years ago

JackTemaki commented 2 years ago

After every epoch I get:

reinit because network description differs. Diff: item 'encoder' dict differs:
  item 'out_shape' differ. self: {Dim{'data_time'[B(-1)]}, Dim{B}, Dim{F'2*(BLSTM-out-dim)'(2048)}}, other: {Dim{'data_time'[B(-1)]}, Dim{F'2*(BLSTM-out-dim)'(2048)}, Dim{B}}
  item 'subnetwork' dict differs:
    item 'blstm_stack' dict differs:
      item 'out_shape' differ. self: {Dim{'data_time'[B(-1)]}, Dim{B}, Dim{F'2*(BLSTM-out-dim)'(2048)}}, other: {Dim{'data_time'[B(-1)]}, Dim{F'2*(BLSTM-out-dim)'(2048)}, Dim{B}}
      item 'subnetwork' dict differs:
        item '0' dict differs:
          item 'out_shape' differ. self: {Dim{'data_time'[B(-1)]}, Dim{B}, Dim{F'2*(BLSTM-out-dim)'(2048)}}, other: {Dim{'data_time'[B(-1)]}, Dim{B}, Dim{F'2*(BLSTM-out-dim)'(2048)}}
          item 'subnetwork' dict differs:
            item 'bwd_lstm' dict differs:
              item 'out_shape' differ. self: {Dim{'data_time'[B(-1)]}, Dim{B}, Dim{F'BLSTM-out-dim'(1024)}}, other: {Dim{'data_time'[B(-1)]}, Dim{B}, Dim{F'BLSTM-out-dim'(1024)}}
              item 'subnetwork' dict differs:
                item 'Parameter.initial' dict differs:
                  item 'shape' differ. self: (Dim{F'BLSTM-out-dim'(1024)}, Dim{F'4*(BLSTM-out-dim)'(4096)}), other: (Dim{F'BLSTM-out-dim'(1024)}, Dim{F'4*(BLSTM-out-dim)'(4096)})
                item 'Parameter.initial_0' dict differs:
                  item 'shape' differ. self: (Dim{F'data_feature'(50)}, Dim{F'4*(BLSTM-out-dim)'(4096)}), other: (Dim{F'data_feature'(50)}, Dim{F'4*(BLSTM-out-dim)'(4096)})
                item 'Parameter.initial_1' dict differs:
                  item 'shape' differ. self: (Dim{F'4*(BLSTM-out-dim)'(4096)},), other: (Dim{F'4*(BLSTM-out-dim)'(4096)},)
                item 'output' dict differs:
                  item 'out_shape' differ. self: {Dim{'data_time'[B(-1)]}, Dim{B}, Dim{F'BLSTM-out-dim'(1024)}}, other: {Dim{'data_time'[B(-1)]}, Dim{B}, Dim{F'BLSTM-out-dim'(1024)}}
                item 'param_W' dict differs:
                  item 'shape'[1] differ. self: Dim{F'4*(BLSTM-out-dim)'(4096)}, other: Dim{F'4*(BLSTM-out-dim)'(4096)}
                item 'param_W_re' dict differs:
                  item 'shape'[0] differ. self: Dim{F'BLSTM-out-dim'(1024)}, other: Dim{F'BLSTM-out-dim'(1024)}
                  item 'shape'[1] differ. self: Dim{F'4*(BLSTM-out-dim)'(4096)}, other: Dim{F'4*(BLSTM-out-dim)'(4096)}
                  [....]

(shortened, as the actual content is not relevant; according to the printed text, all items are the same; the full log is available under /work/asr4/rossenbach/sisyphus_work_folders/tts_asr_2021_work/i6_core/returnn/rasr_training/ReturnnRasrTrainingJob.I9DCEdv2fCp5/log.run.1)

The network is created via:

def get_network(epoch, **kwargs):
    nn.reset_default_root_name_ctx()  # reset the global returnn_common naming context
    # network_kwargs is defined elsewhere in the config; epoch is passed but not used
    net = construct_hybrid_network(epoch=epoch, **network_kwargs)
    return nn.get_returnn_config().get_net_dict_raw_dict(net)

But I think this is independent of the network; the important point is that I do not use the epoch parameter anywhere, so the net should always be identical.

@Atticus1806 you have the same behavior in your setups, right?

albertz commented 2 years ago

What does construct_hybrid_network look like?

I assume that you create new dim tags each time? Then they are not the same, and thus the net dict is different.
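For illustration, a minimal sketch (the dim name here is made up): two FeatureDim calls with identical arguments still create two distinct dim tags.

from returnn_common import nn

# Each call creates a fresh dim tag; equality is by identity, not by
# name and size, so these two compare as different:
out_dim_a = nn.FeatureDim("BLSTM-out-dim", 1024)
out_dim_b = nn.FeatureDim("BLSTM-out-dim", 1024)
assert out_dim_a != out_dim_b

So if the network-construction code creates its dims internally, every call produces a net dict with fresh dim tags, and the comparison against the previous epoch's net dict fails.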

Atticus1806 commented 2 years ago

Yes, the same happens for me. My construct_network looks like this (I assume @JackTemaki's looks similar):

from typing import Optional
from returnn_common import nn

def construct_network(
  epoch: int,
  net_module: nn.Module,
  phoneme_data: nn.Data,  # phoneme labels
  duration_data: nn.Data,  # durations
  label_data: nn.Data,  # speaker labels
  audio_data: nn.Data,  # target speech
  time_dim: nn.Dim,  # phoneme time dim
  label_time_dim: nn.Dim,  # speaker_label time
  speech_time_dim: nn.Dim,  # audio features time
  duration_time_dim: nn.Dim,  # durations time
  speaker_prior: Optional[nn.Data],  # VAE speaker prior
  prior_time: Optional[nn.Dim],  # VAE speaker prior time
  pitch: Optional[nn.Data],  # pitch information
  pitch_time: Optional[nn.Dim],  # pitch information time
  **kwargs
):
  net = net_module(**kwargs)
  out = net(
    text=nn.get_extern_data(phoneme_data),
    durations=nn.get_extern_data(duration_data),
    speaker_labels=nn.get_extern_data(label_data),
    target_speech=nn.get_extern_data(audio_data),
    speaker_prior=nn.get_extern_data(speaker_prior),
    pitch=nn.get_extern_data(pitch),
    time_dim=time_dim,
    label_time=label_time_dim,
    speech_time=speech_time_dim,
    duration_time=duration_time_dim,
    prior_time=prior_time,
    pitch_time=pitch_time,
  )
  out.mark_as_default_output()

  return net

albertz commented 2 years ago

Take, for example, the Dim{F'BLSTM-out-dim'(1024)}: I assume it is created inside that net_module call. But that means you get a different dim for each net_module call.

albertz commented 2 years ago

To conclude, I would not call net_module over and over; instead, I would cache the result and only recreate it when needed.
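For example, a minimal sketch of such caching applied to the get_network above (the cache variable name is made up; this assumes the net really does not depend on epoch):

_net_dict_cache = None  # made-up module-level cache

def get_network(epoch, **kwargs):
    global _net_dict_cache
    if _net_dict_cache is None:
        # Construct only once, so all dim tags are created exactly once
        # and the serialized net dict stays identical across epochs.
        nn.reset_default_root_name_ctx()
        net = construct_hybrid_network(epoch=epoch, **network_kwargs)
        _net_dict_cache = nn.get_returnn_config().get_net_dict_raw_dict(net)
    return _net_dict_cache

If the network ever does depend on epoch, the cache would need to be keyed by epoch (or by whatever the net dict actually depends on).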

albertz commented 2 years ago

I have another idea: when comparing the old net dict to the new net dict, we could try to map the old dim tags to the new ones in some clever way. If a unique mapping exists under which the net dicts are equal, we can treat the whole net dict as unchanged. Finding such a mapping can be tricky when sets are involved, e.g. for {old_dim1, old_dim2} == {new_dim1, new_dim2} it is not clear how to pair them up. But all dim tags should also occur somewhere outside of sets: for specifying the shape of params it must be a list or tuple, and new dim tags always appear somewhere as an out_dim option, so the mapping should actually always be unique.
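A rough sketch of that matching idea (a hypothetical helper, not the actual RETURNN implementation; it assumes Dim is the RETURNN dim tag class and skips sets, which per the above should be resolvable from the non-set occurrences):

from returnn.tensor import Dim  # RETURNN dim tag class

def build_dim_mapping(old, new, mapping=None):
    # Walk old and new net dicts in parallel and collect a unique
    # old-dim -> new-dim mapping. Returns the mapping, or None if the
    # structures differ or the mapping would be inconsistent.
    if mapping is None:
        mapping = {}
    if isinstance(old, Dim) and isinstance(new, Dim):
        if old in mapping and mapping[old] != new:
            return None  # inconsistent: old dim already mapped elsewhere
        mapping[old] = new
        return mapping
    if isinstance(old, dict) and isinstance(new, dict):
        if set(old.keys()) != set(new.keys()):
            return None
        for key in old:
            if build_dim_mapping(old[key], new[key], mapping) is None:
                return None
        return mapping
    if isinstance(old, (list, tuple)) and isinstance(new, (list, tuple)):
        if len(old) != len(new):
            return None
        for a, b in zip(old, new):
            if build_dim_mapping(a, b, mapping) is None:
                return None
        return mapping
    if isinstance(old, (set, frozenset)) or isinstance(new, (set, frozenset)):
        # Ambiguous on its own; would be checked after the mapping has
        # been built from the non-set occurrences (see above).
        return mapping
    return mapping if old == new else None

The two net dicts would then count as equal if such a mapping exists and applying it to the old dict (including the set-valued entries) yields the new one.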

albertz commented 2 years ago

What is the state here? Do you have some other workaround now? Or do you cache it, as I suggested initially? And what about my other suggestion?

albertz commented 2 years ago

Ok, this should be fixed now. Please try.