rwth-i6 / returnn

The RWTH extensible training framework for universal recurrent neural networks
http://returnn.readthedocs.io/

Problems sharing LSTM layer parameters, when preloading from file #945

Closed: aleksglushko closed this issue 2 years ago

aleksglushko commented 2 years ago

Problem

I cannot share the weights of the LSTM layer ('iLMT_s') during training when the recurrent loop is optimized. My suspicion is that RETURNN cannot find the proper name scopes for the variables of the layers that are moved out of the loop.

The weights are shared with a custom function:

def get_var(name, shape):
    from returnn.tf.util.basic import reuse_name_scope
    from returnn.tf.compat import v1 as tf
    # Get (or create) the variable under the absolute (root) name scope,
    # so that its name matches the corresponding entry in the checkpoint.
    with reuse_name_scope('', absolute=True):
        var = tf.get_variable(name, shape)
        print('Reused variable:', var)
        return var

The mapping for the LSTMBlock follows the checkpoint we restore from (weights in the checkpoint):

output/rec/s/rec/lstm_cell/bias (DT_FLOAT) [4000]
output/rec/s/rec/lstm_cell/kernel (DT_FLOAT) [3669,4000]
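
(For reference, the variable names and shapes stored in a checkpoint can be listed with TensorFlow's standard tf.train.list_variables; a minimal sketch, with a placeholder checkpoint path:)

from returnn.tf.compat import v1 as tf

# Print every variable name and shape stored in the checkpoint (placeholder path).
for var_name, var_shape in tf.train.list_variables('/path/to/epoch.300'):
    print(var_name, var_shape)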

Case 1:

Using the same layer type as in the checkpoint, i.e. 'class': 'rnn_cell' with 'unit': 'LSTMBlock':

'iLMT_s': { 'L2': 0.0001, 'class': 'rnn_cell', 'from': ['prev:target_embed', 'zero_att'], 'n_out': 1000,
            'reuse_params': {
                'map': {
                    'kernel': {'custom': lambda **_kwargs: get_var('output/rec/s/rec/rnn/lstm_cell/kernel', _kwargs['shape'])},
                    'bias': {'custom': lambda **_kwargs: get_var('output/rec/s/rec/rnn/lstm_cell/bias', _kwargs['shape'])},
                }},
            'unit': 'LSTMBlock',
            },

log path: /u/glushko/setups/librispeech/2021-13-08--ilmt-att-sis/debug/LSTMBlock_sharing.log

Case 2:

Using a NativeLSTM unit instead of LSTMBlock. The problem remains, since RETURNN cannot find the proper names in the checkpoint:

'iLMT_s': { 'L2': 0.0001, 'class': 'rec', 'from': ['prev:target_embed', 'zero_att'], 'n_out': 1000, 'unit': 'nativelstm2',
                                             'reuse_params': {
                                                    'map': {
                                                               'b': {'custom': lambda **_kwargs: get_var('output/rec/s/rec/b', _kwargs['shape'])},
                                                               'W': {'custom': lambda **_kwargs: get_var('output/rec/s/rec/W', _kwargs['shape'])},
                                                               'W_re': {'custom': lambda **_kwargs: get_var('output/rec/s/rec/W_re', _kwargs['shape'])}
                                                            }}},

log path: /u/glushko/setups/librispeech/2021-13-08--ilmt-att-sis/debug/nativelstm_sharing.log

Network

network = {
  # ... 'source', convolution, and bidirectional LSTM encoder layers (up to 'lstm5_fw'/'lstm5_bw') omitted ...

  'decision': {'class': 'decide', 'from': 'output', 'loss': 'edit_distance', 'target': 'classes', 'trainable': False},
  'enc_ctx': {'L2': 0.0001, 'activation': None, 'class': 'linear', 'from': 'encoder', 'n_out': 1024, 'trainable': False, 'with_bias': True},
  'enc_value': {'axis': 'F', 'class': 'split_dims', 'dims': (1, 2048), 'from': 'encoder', 'trainable': False},
  'encoder': {'class': 'copy', 'from': ['lstm5_fw', 'lstm5_bw'], 'trainable': False},
  'inv_fertility': {'activation': 'sigmoid', 'class': 'linear', 'from': 'encoder', 'n_out': 1, 'trainable': False, 'with_bias': False},

  'output': { 'class': 'rec',
              'from': [],
              'max_seq_len': "max_len_from('base:encoder')",
              'target': 'classes',
              'unit': { 
                       'accum_att_weights': { 'class': 'eval', 'eval': 'source(0) + source(1) * source(2) * 0.5',
                                                                'from': ['prev:accum_att_weights', 'att_weights', 'base:inv_fertility'], 
                                                                'out_type': {'dim': 1, 'shape': (None, 1)}},
                        'att': {'axes': 'except_batch', 'class': 'merge_dims', 'from': 'att0'},
                        'att0': {'base': 'base:enc_value', 'class': 'generic_attention', 'weights': 'att_weights'},
                        'att_weights': {'class': 'softmax_over_spatial', 'from': 'energy'},
                        'end': {'class': 'compare', 'from': 'output', 'kind': 'equal', 'value': 0},
                        'energy': {'activation': None, 'class': 'linear', 'from': 'energy_tanh', 'n_out': 1, 'with_bias': False},
                        'energy_in': {'class': 'combine', 'from': ['base:enc_ctx', 'weight_feedback', 's_transformed'], 
                                             'kind': 'add', 'n_out': 1024},
                        'energy_tanh': {'activation': 'tanh', 'class': 'activation', 'from': 'energy_in'},
                        'iLMT_output_prob': { 'L2': 0.0001, 'class': 'softmax', 'dropout': 0.3, 'from': 'iLMT_readout', 'loss': 'ce',
                                              'loss_opts': {'label_smoothing': 0.1, 'scale': 0.4}, 'target': 'classes',
                                              'reuse_params': { 
                                                   'map': { 'W': { 'custom': lambda **_kwargs: get_var('output/rec/output_prob/W', _kwargs['shape'])},
                                                   'b': { 'custom': lambda **_kwargs: get_var('output/rec/output_prob/b', _kwargs['shape'])}}}},
                        'iLMT_readout': { 'class': 'reduce_out', 'from': 'iLMT_readout_in', 'mode': 'max', 'num_pieces': 2,
                                          'reuse_params': { 
                                                   'map': { 'W': { 'custom': lambda **_kwargs: get_var('output/rec/readout/W', _kwargs['shape'])},
                                                   'b': { 'custom': lambda **_kwargs: get_var('output/rec/readout/b', _kwargs['shape'])}}}},
                        'iLMT_readout_in': { 'activation': None, 'class': 'linear', 'from': ['iLMT_s', 'prev:target_embed', 'zero_att'],
                                             'n_out': 1000, 'with_bias': True,
                                             'reuse_params': { 
                                                    'map': { 'W': { 'custom': lambda **_kwargs: get_var('output/rec/readout_in/W', _kwargs['shape'])},
                                                    'b': { 'custom': lambda **_kwargs: get_var('output/rec/readout_in/b', _kwargs['shape'])}}}},
                        'iLMT_s': { 'L2': 0.0001, 'class': 'rec', 'from': ['prev:target_embed', 'zero_att'], 'n_out': 1000, 'unit': 'nativelstm2',
                                              'reuse_params': {
                                                     'map': {
                                                                'b': {'custom': lambda **_kwargs: get_var('output/rec/s/rec/b', _kwargs['shape'])},
                                                                'W': {'custom': lambda **_kwargs: get_var('output/rec/s/rec/W', _kwargs['shape'])},
                                                                'W_re': {'custom': lambda **_kwargs: get_var('output/rec/s/rec/W_re', _kwargs['shape'])}
                                                             }}},
                        'output': {'beam_size': 12, 'class': 'choice', 'from': 'output_prob', 'initial_output': 0, 'target': 'classes'},
                        'output_prob': { 'L2': 0.0001, 'class': 'softmax', 'dropout': 0.3, 'from': 'readout', 'loss': 'ce',
                                                   'loss_opts': {'label_smoothing': 0.1}, 'target': 'classes'},
                        'readout': {'class': 'reduce_out', 'from': 'readout_in', 'mode': 'max', 'num_pieces': 2},
                        'readout_in': { 'activation': None, 'class': 'linear', 'from': ['s', 'prev:target_embed', 'att'], 'n_out': 1000,
                                                'with_bias': True},
                        's': {'L2': 0.0001, 'class': 'rnn_cell', 'from': ['prev:target_embed', 'prev:att'], 'n_out': 1000, 'unit': 'LSTMBlock'},
                        's_transformed': {'activation': None, 'class': 'linear', 'from': 's', 'n_out': 1024, 'with_bias': False},
                        'target_embed': { 'activation': None, 'class': 'linear', 'from': 'output', 'initial_output': 0, 'n_out': 621,
                                                     'with_bias': False},
                        'weight_feedback': { 'activation': None, 'class': 'linear', 'from': 'prev:accum_att_weights', 'n_out': 1024,
                                                     'with_bias': False},
                        'zero_att': {'class': 'eval', 'eval': 'tf.zeros_like(source(0))', 'from': 'att'}}},

preload_from_files = {
    'baseline_LSTM_weights': {
        'filename': '/work/asr3/zeineldeen/hiwis/glushko/setups-data/librispeech/2021-13-08--ilmt-att-sis/work/crnn/training/CRNNTrainingJob.hLLldrjunwA2/output/models/epoch.300',
        'ignore_missing': False,
        'init_for_train': True}}

config path:

/u/glushko/setups/librispeech/2021-13-08--ilmt-att-sis/debug/ilmt_debug.config

albertz commented 2 years ago

What is the error you get in each case?

albertz commented 2 years ago

tf.get_variable is not about the checkpoint. The model is constructed (and this code is executed) before the checkpoint is loaded.

aleksglushko commented 2 years ago

In the first case:

  File "/u/glushko/setups/librispeech/2021-13-08--ilmt-att-sis/returnn/returnn/tf/layers/base.py", line 2106, in ReuseParams.variable_custom_getter
    line: assert param_name in self.param_map
    locals:
      param_name = <local> 'lstm_cell/kernel', len = 16
      self = <local> <ReuseParams reuse_layer None, map {'kernel': <ReuseParams reuse_layer None, map None>, 'bias': <ReuseParams reuse_layer None, map None>}>
      self.param_map = <local> {'kernel': <ReuseParams reuse_layer None, map None>, 'bias': <ReuseParams reuse_layer None, map None>}

In the second:

  File "/u/glushko/setups/librispeech/2021-13-08--ilmt-att-sis/returnn/returnn/tf/network.py", line 4020, in CustomCheckpointLoader.get_variable_value_map
    line: raise tf.errors.NotFoundError(
            node_def=None, op=None,
            message="CustomCheckpointLoader. could_not_find_map_list: %r" % (could_not_find_map_list,))
    locals:
      tf = <global> <module 'tensorflow' from '/work/tools/asr/python/3.8.0_tf_2.3-v1-generic+cuda10.1/lib/python3.8/site-packages/tensorflow/__init__.py'>
      tf.errors = <global> <module 'tensorflow._api.v2.errors' from '/work/tools/asr/python/3.8.0_tf_2.3-v1-generic+cuda10.1/lib/python3.8/site-packages/tensorflow/_api/v2/errors/__init__.py'>
      tf.errors.NotFoundError = <global> <class 'tensorflow.python.framework.errors_impl.NotFoundError'>
      node_def = <not found>
      op = <not found>
      message = <not found>
      could_not_find_map_list = <local> ['output/rec/s/rec/W', 'output/rec/s/rec/W_re', 'output/rec/s/rec/b'], _[0]: {len = 18}
NotFoundError: CustomCheckpointLoader. could_not_find_map_list: ['output/rec/s/rec/W', 'output/rec/s/rec/W_re', 'output/rec/s/rec/b']

In the second case, I thought the function that maps LSTMBlock -> NativeLSTM would map kernel and bias into W, W_re, b, and that the variables would then be loaded during construction. What is also unclear is that the sharing itself seemed to work, since this appears in the log:

Reused variable:  <tf.Variable 'output/rec/s/rec/W:0' shape=(2669, 4000) dtype=float32>
Reused variable:  <tf.Variable 'output/rec/s/rec/b:0' shape=(4000,) dtype=float32>
Reused variable:  <tf.Variable 'output/rec/s/rec/W_re:0' shape=(1000, 4000) dtype=float32>
layer root/output(rec-subnet-output)/'iLMT_readout_in' output: Data{'iLMT_readout_in_output', [T|'time:var:extern_data:classes'[B],B,F|F'iLMT_s:feature'(1000)]}
Reused variable:  <tf.Variable 'output/rec/readout_in/W:0' shape=(3669, 1000) dtype=float32>
Reused variable:  <tf.Variable 'output/rec/readout_in/b:0' shape=(1000,) dtype=float32>
layer root/output(rec-subnet-output)/'iLMT_readout' output: Data{'iLMT_readout_output', [T|'time:var:extern_data:classes'[B],B,F|F'iLMT_s:feature//2'(500)]}
layer root/output(rec-subnet-output)/'iLMT_output_prob' output: Data{'iLMT_output_prob_output', [T|'time:var:extern_data:classes'[B],B,F|F'iLMT_output_prob:feature-dense'(10025)]}
Reused variable:  <tf.Variable 'output/rec/output_prob/W:0' shape=(500, 10025) dtype=float32>
Reused variable:  <tf.Variable 'output/rec/output_prob/b:0' shape=(10025,) dtype=float32>

albertz commented 2 years ago

The first case is obvious, or not? You map kernel and bias in your config, but you should map lstm_cell/kernel and lstm_cell/bias instead.

In the second case, this automatic conversion LSTMBlock -> NativeLSTM only works without such custom variable maps.

You could first use the original config and only replace LSTMBlock by NativeLSTM2, and then load the old checkpoint, and just save it directly. When it loads, this converts the params, and then you have a checkpoint with NativeLSTM2. Then you don't need the conversion later on and you can use custom variable maps.
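
A rough sketch of what such a conversion config could look like (paths are placeholders, not a verified recipe; the network definition stays the original one, just with every 'unit': 'LSTMBlock' replaced by 'unit': 'nativelstm2'):

# Loading the old checkpoint triggers RETURNN's automatic LSTMBlock -> NativeLSTM2
# parameter conversion; the next checkpoint that gets saved then already uses
# the NativeLSTM2 variable names, so no conversion is needed later on.
load = '/path/to/old/LSTMBlock/epoch.300'        # old checkpoint (placeholder)
model = '/path/to/converted/nativelstm2/epoch'   # output checkpoint prefix (placeholder)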

In general, avoid LSTMBlock, and just always use NativeLSTM2.

I guess this issue can be closed as there is no real issue on RETURNN side. But feel free to post any follow-up questions. Or also ask on Slack in the returnn channel.

aleksglushko commented 2 years ago

If I map it like you said:

'map': {
      'lstm_cell/kernel' : {'custom': lambda **_kwargs: get_var('output/rec/s/rec/rnn/lstm_cell/kernel', _kwargs['shape'])},
      'lstm_cell/bias' : {'custom': lambda **_kwargs: get_var('output/rec/s/rec/rnn/lstm_cell/bias', _kwargs['shape'])},
            }

Then it doesn't see the params: params = []

albertz commented 2 years ago

If I map it like you said:

'map': {
      'lstm_cell/kernel' : {'custom': lambda **_kwargs: get_var('output/rec/s/rec/rnn/lstm_cell/kernel', _kwargs['shape'])},
      'lstm_cell/bias' : {'custom': lambda **_kwargs: get_var('output/rec/s/rec/rnn/lstm_cell/bias', _kwargs['shape'])},
            }

Then it doesn't see the params: params = []

What do you mean by that? Is there some error? What is the error?

If you want the params to be in the params dict of the layer, you could do:

In your lambda where you call the get_var function, just pass all the kwargs, like get_var(..., **kwargs).

Then you also get base_layer. Extend the code to:

var = tf.get_variable(name, shape)
base_layer.params[name] = var
return var

However, I don't think you actually need that.
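
If you do need it, a sketch of the combined get_var could look like this (assuming base_layer is indeed among the kwargs passed to the custom getter, as described above; the usage lambda is hypothetical):

def get_var(name, shape, base_layer=None, **_other_kwargs):
    from returnn.tf.util.basic import reuse_name_scope
    from returnn.tf.compat import v1 as tf
    with reuse_name_scope('', absolute=True):
        var = tf.get_variable(name, shape)
    if base_layer is not None:
        # Also register the shared variable in the layer's own params dict.
        base_layer.params[name] = var
    return var

# Hypothetical usage in the config:
# 'custom': lambda **kw: get_var('output/rec/s/rec/W', kw['shape'], base_layer=kw.get('base_layer'))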

aleksglushko commented 2 years ago

Yes, there is an error: /u/glushko/setups/librispeech/2021-13-08--ilmt-att-sis/debug/no_params.log

File "/u/glushko/setups/librispeech/2021-13-08--ilmt-att-sis/returnn/returnn/tf/network.py", line 1101, in TFNetwork.add_layer
    line: layer = self._create_layer(name=name, layer_class=layer_class, **layer_desc)
    locals:
      layer = <not found>
      self = <local> <TFNetwork 'root/output(rec-subnet-output)' parent_layer=<RecLayer 'output' out_type=Data{[T|'time:var:extern_data:classes'[B],B], dtype='int32', sparse_dim=Dim{F'classes:sparse-dim'(10025)}}> train=<tf.Tensor 'globals/train_flag:0' shape=() dtype=bool>>
      self._create_layer = <local> <bound method TFNetwork._create_layer of <TFNetwork 'root/output(rec-subnet-output)' parent_layer=<RecLayer 'output' out_type=Data{[T|'time:var:extern_data:classes'[B],B], dtype='int32', sparse_dim=Dim{F'classes:sparse-dim'(10025)}}> train=<tf.Tensor 'globals/train_flag:0' shape=() dtype=bool>>>
      name = <local> 'iLMT_s', len = 6
      layer_class = <local> <class 'returnn.tf.layers.rec.RnnCellLayer'>
      layer_desc = <local> {'L2': 0.0001, 'n_out': 1000, 'reuse_params': <ReuseParams reuse_layer None, map {'lstm_cell/kernel': <ReuseParams reuse_layer None, map None>, 'lstm_cell/bias': <ReuseParams reuse_layer None, map None>}>, 'unit': 'LSTMBlock', '_network': <TFNetwork 'root/output(rec-subnet-output)' parent_layer=<..., len = 7
  File "/u/glushko/setups/librispeech/2021-13-08--ilmt-att-sis/returnn/returnn/tf/network.py", line 1016, in TFNetwork._create_layer
    line: layer = layer_class(**layer_desc)
    locals:
      layer = <not found>
      layer_class = <local> <class 'returnn.tf.layers.rec.RnnCellLayer'>
      layer_desc = <local> {'L2': 0.0001, 'n_out': 1000, 'reuse_params': <ReuseParams reuse_layer None, map {'lstm_cell/kernel': <ReuseParams reuse_layer None, map None>, 'lstm_cell/bias': <ReuseParams reuse_layer None, map None>}>, 'unit': 'LSTMBlock', '_network': <TFNetwork 'root/output(rec-subnet-output)' parent_layer=<..., len = 10
  File "/u/glushko/setups/librispeech/2021-13-08--ilmt-att-sis/returnn/returnn/tf/layers/rec.py", line 4387, in RnnCellLayer.__init__
    line: assert params
    locals:
      params = <local> []
AssertionError

albertz commented 2 years ago

Ah, this assert is not necessary (or rather wrong) there. I pushed a fix for this. Can you try again?

aleksglushko commented 2 years ago

It is working now, thank you!