rwth-i6 / returnn

The RWTH extensible training framework for universal recurrent neural networks
http://returnn.readthedocs.io/

Adding new inputs to pre-trained layers #421

Open michelwi opened 3 years ago

michelwi commented 3 years ago

Problem Statement

I would like to extend an already existing layer with an additional input. In my example, I have trained an attention-based encoder-decoder model, and now I would like to add the output of an external LM as an additional input to the softmax layer:

My returnn config for training the checkpoint was:

'output_prob': { 'class': 'softmax',
                 'dropout': 0.0,
                 'from': ['readout'], # input size: 500
                 'loss': 'ce',
                 'target': 'classes'}, # output size: 10025

Now I continue training with this extended layer, importing the existing checkpoint via the preload_from_files mechanism:

'output_prob': { 'bias_init': 0,
                 'class': 'softmax',
                 'custom_param_importer': 'subset',
                 'dropout': 0.0,
                 'forward_weights_init': 0,
                 'from': ['lm_output_prob', 'readout'], # input size: 10025 + 500
                 'loss': 'ce',
                 'target': 'classes'}, # output size: 10025
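Zero-initializing the new weights is a sensible starting point. Below is a minimal pure-Python sketch of why. It rests on these assumptions: the two sources in 'from' are concatenated along the feature axis in the listed order (LM probabilities first, then readout), W has shape (n_in, n_out), and the checkpoint's readout weights and bias are copied into the matching part. The toy dimensions stand in for 10025 and 500, and all names are hypothetical.

```python
# Toy sizes standing in for the real dims (10025 LM outputs, 500 readout units).
N_LM, N_READOUT, N_OUT = 3, 2, 3


def matvec(x, w):
  """Compute x @ w for a vector x and a matrix w given as nested lists."""
  return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


# "Old" softmax layer from the checkpoint: W_old has shape (N_READOUT, N_OUT).
w_old = [[0.5, -1.0, 0.25],
         [2.0, 0.0, -0.5]]
b_old = [0.1, 0.2, 0.3]

# New layer: input is concat([lm_output_prob, readout]), so W_new has shape
# (N_LM + N_READOUT, N_OUT). Rows for the new LM input start at zero
# (forward_weights_init=0); rows for readout are copied from the checkpoint.
w_new = [[0.0] * N_OUT for _ in range(N_LM)] + [row[:] for row in w_old]
b_new = b_old[:]  # the bias shape is unchanged, so it is copied as-is

lm_output_prob = [0.7, 0.2, 0.1]
readout = [1.0, -2.0]

logits_old = [a + b for a, b in zip(matvec(readout, w_old), b_old)]
logits_new = [a + b for a, b in zip(matvec(lm_output_prob + readout, w_new), b_new)]
print(logits_old == logits_new)  # the extended layer starts out identical to the old one
```

Training then only has to learn how much to use the LM rows, starting from the old model's behavior.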

First Attempt

I am on commit face0c3be7d5336ac38c106ccc6606af70bd7ac9, which I extended with the changes introduced in #412.

The relevant variables in the function transform_param_axes_split_info_to_new_shape in returnn/tf/util/basic.py are:

axes_split_info = [[10025, 500], [10025]]
new_shape = (500, 10025)
dim_diff = {10025: 10025}
new_parts = [10025, None]
Here is the stack trace I get:

```
Unhandled exception in thread <_MainThread(MainThread, started 140417611712256)>, proc 7117.

Thread current, main, <_MainThread(MainThread, started 140417611712256)>:
(Excluded thread.)

That were all threads.
EXCEPTION
Traceback (most recent call last):
  File "/work/asr4/michel/sandbow/returnn_meyer/rnn.py", line 11, in <module>
    line: main()
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/__main__.py", line 645, in main
    line: execute_main_task()
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/__main__.py", line 451, in execute_main_task
    line: engine.init_train_from_config(config, train_data, dev_data, eval_data)
    locals: eval_data = None
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/engine.py", line 1026, in init_train_from_config
    line: self.init_network_from_config(config)
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/engine.py", line 1127, in init_network_from_config
    line: loader.load_now(session=self.tf_session)
    locals: loader = CustomCheckpointLoader(filename='/work/asr3/michel/meyer/work/crnn/training/CRNNTrainingJob.737fYxfbkoCz/output/models/epoch.250', params_prefix='', load_if_prefix='', ignore_missing=True, network=...)
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/network.py", line 3151, in load_now
    line: value.assign_var(var=var, session=session)
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/network.py", line 2832, in assign_var
    line: self.custom_param_importer.assign_var(var=var, session=session)
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/network.py", line 2776, in assign_var
    line: self.layer.set_param_values_by_dict(values_dict=values_dict, session=session)
    locals: values_dict = {'W': array([[ 0.07109621, 0.01965252, 0.01508382, ..., 0.01965284, 0.02135382, 0.01964182], [-0.04140154, -0.12580605, -0.14898466, ..., -0.12578817, -0.12482522, -0.1258727 ], [ 0.11936312, 0.05931459, 0.03965379, ..., 0.05929022, 0.06279352, 0.05...
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/layers/base.py", line 897, in set_param_values_by_dict
    line: old_axes_splits = tf_util.transform_param_axes_split_info_to_new_shape(
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/util/basic.py", line 184, in transform_param_axes_split_info_to_new_shape
    line: assert new_parts[j] > 0, debug_name
    locals:
      new_parts = [10025, -9525]
      j = 1
      debug_name = "param 'output/rec/output_prob/W:0'", len = 34
AssertionError: param 'output/rec/output_prob/W:0'
```

OK, so in the loop at https://github.com/rwth-i6/returnn/blob/face0c3be7d5336ac38c106ccc6606af70bd7ac9/returnn/tf/util/basic.py#L155, the first parts = [10025, 500] matches neither condition, and only the second parts = [10025] sets dim_diff = {10025: 10025}. So we end up in https://github.com/rwth-i6/returnn/blob/face0c3be7d5336ac38c106ccc6606af70bd7ac9/returnn/tf/util/basic.py#L167, and the heuristic fails: new_parts becomes [10025, -9525], i.e. the unknown part gets 500 - 10025.
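To make the failure concrete, here is a heavily simplified, hypothetical re-implementation of the two steps described above. It is not the actual RETURNN code. The first loop only handles the single-part and all-equal cases (the two original conditions). The fallback fills a single unknown part with new_dim minus the known parts. With the values from this issue, it reproduces the negative size from the stack trace.

```python
def infer_dim_diff(axes_split_info, new_shape):
  """First loop: map each old part dim to a guessed new dim (original two conditions only)."""
  dim_diff = {}
  for new_dim, parts in zip(new_shape, axes_split_info):
    if len(parts) == 1:
      dim_diff[parts[0]] = new_dim
    elif len(set(parts)) == 1:  # all parts have the same size
      if new_dim % len(parts) == 0:
        dim_diff[parts[0]] = new_dim // len(parts)
  return dim_diff


def fill_missing_part(parts, new_dim, dim_diff):
  """Fallback: if exactly one part is unknown, give it whatever is left of new_dim."""
  new_parts = [dim_diff.get(d) for d in parts]
  if new_parts.count(None) == 1:
    j = new_parts.index(None)
    new_parts[j] = new_dim - sum(d for d in new_parts if d is not None)
  return new_parts


axes_split_info = [[10025, 500], [10025]]
new_shape = (500, 10025)
dim_diff = infer_dim_diff(axes_split_info, new_shape)
print(dim_diff)  # {10025: 10025}
print(fill_missing_part(axes_split_info[0], new_shape[0], dim_diff))  # [10025, -9525]
```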

Second Attempt

Then I added my case to the end of the loop at https://github.com/rwth-i6/returnn/blob/face0c3be7d5336ac38c106ccc6606af70bd7ac9/returnn/tf/util/basic.py#L155:

  for new_dim, parts in zip(new_shape, axes_split_info):
    if len(parts) == 1:
      dim_diff[parts[0]] = new_dim
    elif len(set(parts)) == 1:  # all the same
      if new_dim % len(parts) == 0:
        dim_diff[parts[0]] = new_dim // len(parts)  # just a heuristic
    elif sum(parts[1:]) == new_dim:  # added one input in front (see heuristic below)
      dim_diff[parts[0]] = 0
      dim_diff.update({dim: dim for dim in parts[1:]})
    elif new_dim in parts:  # all inputs are new except one
      dim_diff.update({dim: 0 for dim in parts})
      dim_diff[new_dim] = new_dim

Now I get the following variables:

axes_split_info = [[10025, 500], [10025]]
new_shape = (500, 10025)
dim_diff = {10025: 10025, 500: 500}
new_parts = [10025, 500]
And here is the new stack trace:

```
Unhandled exception in thread <_MainThread(MainThread, started 139845223368448)>, proc 28243.

Thread current, main, <_MainThread(MainThread, started 139845223368448)>:
(Excluded thread.)

That were all threads.
EXCEPTION
Traceback (most recent call last):
  File "/work/asr4/michel/sandbow/returnn_meyer/rnn.py", line 11, in <module>
    line: main()
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/__main__.py", line 645, in main
    line: execute_main_task()
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/__main__.py", line 451, in execute_main_task
    line: engine.init_train_from_config(config, train_data, dev_data, eval_data)
    locals: eval_data = None
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/engine.py", line 1026, in init_train_from_config
    line: self.init_network_from_config(config)
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/engine.py", line 1127, in init_network_from_config
    line: loader.load_now(session=self.tf_session)
    locals: loader = CustomCheckpointLoader(filename='/work/asr3/michel/meyer/work/crnn/training/CRNNTrainingJob.737fYxfbkoCz/output/models/epoch.250', params_prefix='', load_if_prefix='', ignore_missing=True, network=...)
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/network.py", line 3151, in load_now
    line: value.assign_var(var=var, session=session)
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/network.py", line 2832, in assign_var
    line: self.custom_param_importer.assign_var(var=var, session=session)
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/network.py", line 2776, in assign_var
    line: self.layer.set_param_values_by_dict(values_dict=values_dict, session=session)
    locals: values_dict = {'b': array([ 1.8866221 , -0.3177627 , 3.6245446 , ..., -0.3159841 , -0.31281865, -0.31379807], dtype=float32), 'W': array([[ 0.07109621, 0.01965252, 0.01508382, ..., 0.01965284, 0.02135382, 0.01964182], [-0.04140154, -0.12580605, -0.14898466, ..., -0.12578817, ...
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/layers/base.py", line 897, in set_param_values_by_dict
    line: old_axes_splits = tf_util.transform_param_axes_split_info_to_new_shape(
  File "/work/asr4/michel/sandbow/returnn_meyer/returnn/tf/util/basic.py", line 189, in transform_param_axes_split_info_to_new_shape
    line: assert new_parts[0] > 0, debug_name
    locals:
      new_parts = [0, 500]
      debug_name = "param 'output/rec/output_prob/W:0'", len = 34
AssertionError: param 'output/rec/output_prob/W:0'
```

Now the first parts = [10025, 500] hits my newly defined third condition and sets dim_diff = {10025: 0, 500: 500}, and then the second parts = [10025] overwrites it to dim_diff = {10025: 10025, 500: 500}. The new_parts for the first dim should have been [0, 500], but due to the overwriting it is now [10025, 500], and we trigger https://github.com/rwth-i6/returnn/blob/face0c3be7d5336ac38c106ccc6606af70bd7ac9/returnn/tf/util/basic.py#L175, which fortunately saves us and sets new_parts = [0, 500].

Now only the assertion new_parts[0] > 0 fails, as expected.
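The overwriting can be reproduced with just the modified first loop. The sketch below copies the loop from the second attempt into a standalone function (dim_diff_history is a hypothetical name) and records a snapshot of dim_diff after each axis. It also shows that listing the same axes in the opposite order gives a different final dim_diff.

```python
def dim_diff_history(axes_split_info, new_shape):
  """Run the modified first loop from the second attempt, recording dim_diff after each axis."""
  dim_diff = {}
  history = []
  for new_dim, parts in zip(new_shape, axes_split_info):
    if len(parts) == 1:
      dim_diff[parts[0]] = new_dim
    elif len(set(parts)) == 1:  # all the same
      if new_dim % len(parts) == 0:
        dim_diff[parts[0]] = new_dim // len(parts)  # just a heuristic
    elif sum(parts[1:]) == new_dim:  # added one input in front
      dim_diff[parts[0]] = 0
      dim_diff.update({dim: dim for dim in parts[1:]})
    elif new_dim in parts:  # all inputs are new except one
      dim_diff.update({dim: 0 for dim in parts})
      dim_diff[new_dim] = new_dim
    history.append(dict(dim_diff))  # snapshot, not a reference
  return history


history = dim_diff_history([[10025, 500], [10025]], (500, 10025))
print(history)  # [{10025: 0, 500: 500}, {10025: 10025, 500: 500}]

# Same parameter, axes listed in the opposite order: the final dim_diff differs.
swapped = dim_diff_history([[10025], [10025, 500]], (10025, 500))
print(swapped[-1])  # {10025: 0, 500: 500}
```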

Questions

albertz commented 3 years ago

The names transform_param_axes_split_info_to_new_shape and new_shape are maybe confusing. new_shape is actually the old shape here, i.e. the shape of the parameter in the checkpoint, because the function is used the other way around. So in your case, it means you start with 2 input layers and then remove one.

The way the output of transform_param_axes_split_info_to_new_shape is used in copy_with_new_split_axes actually requires that we put the 0 in there.
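One way to see why the 0 is needed: suppose the copy step walks over the split parts pairwise. Then a part of size 0 in the old split means "this input did not exist before, keep the new initialization". The function below is a hypothetical pure-Python sketch of such a copy along the input axis only. It is not the actual copy_with_new_split_axes, whose exact behavior may differ.

```python
def copy_rows_with_new_split(old_rows, old_split, new_rows_init, new_split):
  """Copy old parameter rows part by part into a freshly initialized matrix.

  old_split / new_split list the sizes of the input parts along the row axis.
  A part of size 0 in old_split did not exist in the old parameter, so the
  corresponding new rows keep their initial values.
  """
  assert len(old_split) == len(new_split)
  assert sum(old_split) == len(old_rows) and sum(new_split) == len(new_rows_init)
  result = [row[:] for row in new_rows_init]
  old_offset = new_offset = 0
  for old_size, new_size in zip(old_split, new_split):
    for i in range(min(old_size, new_size)):
      result[new_offset + i] = old_rows[old_offset + i][:]
    old_offset += old_size
    new_offset += new_size
  return result


# Toy sizes: 3 "LM" inputs are new, 2 "readout" inputs existed before.
old_w = [[1.0, 2.0], [3.0, 4.0]]         # old split: [0, 2]
init_w = [[0.0, 0.0] for _ in range(5)]  # new split: [3, 2], zero-initialized
new_w = copy_rows_with_new_split(old_w, [0, 2], init_w, [3, 2])
print(new_w)  # [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [1.0, 2.0], [3.0, 4.0]]
```

Suppose the old split were instead something like [3, 2] (the analogue of [10025, 500]). Then sum(old_split) would not match the rows actually stored in the checkpoint. That is why the inferred old split has to contain the 0.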

So what we want:

  assert_equal(transform_param_axes_split_info_to_new_shape([[10025, 500], [10025]], (500, 10025)), [[0, 500], [10025]])

You should add that to test_transform_param_axes_split_info_to_new_shape.

maybe transform_param_axes_split_info_to_new_shape is used elsewhere

Just check that. I think not.

Is it safe to allow the case new_parts[0] == 0 ...

Why not? Also, allowing something that was not allowed before cannot possibly break anything.

Now the first parts ... hits my newly defined third condition ..., then the second parts ... overwrites it ....

That's way too ugly. You should not rely on a specific order of the axes for your heuristic to work. (Yes, you can already create strange edge cases where it will break depending on the order, but ignore those. Just do not make it explicitly depend on such behavior.)

I would add something like this to the first loop:

elif new_dim in parts:
  dim_diff[new_dim] = new_dim
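Here is a minimal sketch of the first loop with only this branch added. The function name is hypothetical, and the other two branches are copied from the snippet in the issue; this is not RETURNN's code. Unlike the second attempt, the branch only sets dim_diff[d] = d for the matching dim. In this example, therefore, the result is the same whichever order the axes come in.

```python
def infer_dim_diff_suggested(axes_split_info, new_shape):
  """First loop of the heuristic with the suggested `new_dim in parts` branch added."""
  dim_diff = {}
  for new_dim, parts in zip(new_shape, axes_split_info):
    if len(parts) == 1:
      dim_diff[parts[0]] = new_dim
    elif len(set(parts)) == 1:  # all the same
      if new_dim % len(parts) == 0:
        dim_diff[parts[0]] = new_dim // len(parts)  # just a heuristic
    elif new_dim in parts:
      dim_diff[new_dim] = new_dim
  return dim_diff


forward = infer_dim_diff_suggested([[10025, 500], [10025]], (500, 10025))
backward = infer_dim_diff_suggested([[10025], [10025, 500]], (10025, 500))
print(forward == backward == {10025: 10025, 500: 500})  # True: the axis order no longer matters here
```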

And in the second loop, in the sum ... != new_dim if-branch, I would add another check:

if new_dim in new_parts:
  new_parts = ...
else:
  ... (as before)

Also change assert new_parts[j] > 0 to assert new_parts[j] >= 0.
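Below is a hypothetical sketch of the second-loop step for a single axis, with the extra check and the relaxed assertion. The "(as before)" branch is left out and raises instead. dim_diff is passed in as the result of the first loop, and the rule "keep the first matching part, set the others to 0" is an assumption filling in the "new_parts = ..." above. With the values from this issue, it yields the [[0, 500], [10025]] expected by the proposed test.

```python
def resolve_parts(parts, new_dim, dim_diff, debug_name="param"):
  """Map the parts of one axis to their sizes in the other shape (hypothetical, simplified)."""
  if sum(parts) == new_dim:  # same dim, nothing to transform
    return list(parts)
  new_parts = [dim_diff.get(d) for d in parts]
  if None in new_parts:
    raise NotImplementedError("unknown parts: handled by the existing heuristics, not sketched here")
  if sum(new_parts) != new_dim:
    if new_dim in new_parts:
      # Suggested extra check: one part keeps new_dim, all other parts get size 0.
      # Assumption: if several parts match, the first one is kept.
      j = new_parts.index(new_dim)
      new_parts = [new_dim if i == j else 0 for i in range(len(new_parts))]
    else:
      raise NotImplementedError("the previous heuristics would go here")
  for j in range(len(new_parts)):
    assert new_parts[j] >= 0, debug_name  # relaxed from > 0, so empty parts are allowed
  assert sum(new_parts) == new_dim, debug_name
  return new_parts


axes_split_info = [[10025, 500], [10025]]
new_shape = (500, 10025)
dim_diff = {10025: 10025, 500: 500}  # what the first loop with the suggested branch yields
result = [resolve_parts(parts, new_dim, dim_diff) for new_dim, parts in zip(new_shape, axes_split_info)]
print(result)  # [[0, 500], [10025]]
```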

Can we rewrite the whole mechanism to not use a dict dim_diff ...

I'm quite sure this would break other cases.

Maybe instead of dict[int,int], it could be dict[int,set[int]].
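Here is one possible reading of the dict[int, set[int]] idea, as a hypothetical sketch. Instead of overwriting, every branch adds its guess to a set. The first three branches from the second attempt then keep both candidates 0 and 10025 for the dim 10025, regardless of axis order. How the second loop would choose among the candidates is left open here.

```python
from collections import defaultdict


def collect_dim_candidates(axes_split_info, new_shape):
  """First loop collecting all guesses per old dim instead of overwriting earlier ones."""
  candidates = defaultdict(set)
  for new_dim, parts in zip(new_shape, axes_split_info):
    if len(parts) == 1:
      candidates[parts[0]].add(new_dim)
    elif len(set(parts)) == 1:  # all the same
      if new_dim % len(parts) == 0:
        candidates[parts[0]].add(new_dim // len(parts))
    elif sum(parts[1:]) == new_dim:  # added one input in front
      candidates[parts[0]].add(0)
      for dim in parts[1:]:
        candidates[dim].add(dim)
  return dict(candidates)


print(collect_dim_candidates([[10025, 500], [10025]], (500, 10025)))  # {10025: {0, 10025}, 500: {500}}
print(collect_dim_candidates([[10025], [10025, 500]], (10025, 500)))  # {10025: {0, 10025}, 500: {500}}
```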

However, I would also try not to make this too complicated.

Note that there will always be cases which this does not fully cover, no matter what you do. I think it's fine for now to add one or two more heuristics to cover your case, but we should not make it too complicated.

Maybe there is a better way how to infer the old_axes_splits or to copy over the parameters.

albertz commented 3 years ago

Is this resolved?

Note that we also don't really need to invest so much energy into making the heuristic work correctly in all cases (which is not possible anyway). Instead, we could invest some energy into doing it the right way, so that it is always correct, in a clean and predictable way, without relying on such a heuristic. This should certainly be possible, right?

E.g., the problem this heuristic tries to solve is recovering the old shape split information, which is no longer available at this point. Maybe we could just store this in the checkpoint or somewhere else?
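Here is a minimal sketch of the "store it somewhere else" idea. It assumes a hypothetical JSON side file next to the checkpoint; RETURNN does not necessarily offer such a mechanism, and the file name, keys and helper functions are made up for illustration. It also assumes that each input part is stored together with its source layer name. With that, the old split for the extended layer can be looked up exactly instead of guessed.

```python
import json
import os
import tempfile


def save_split_info(checkpoint_path, split_info):
  """Write {param_name: {"input_parts": [[source_name, dim], ...]}} next to the checkpoint."""
  with open(checkpoint_path + ".split_info.json", "w") as f:
    json.dump(split_info, f)


def load_split_info(checkpoint_path):
  """Read the stored split info back; None for checkpoints written without it."""
  path = checkpoint_path + ".split_info.json"
  if not os.path.exists(path):
    return None
  with open(path) as f:
    return json.load(f)


def old_input_split(stored_parts, new_parts):
  """Old size of each new (source_name, dim) input part; sources that did not exist before get 0."""
  old_sizes = {name: dim for name, dim in stored_parts}
  return [old_sizes.get(name, 0) for name, _ in new_parts]


with tempfile.TemporaryDirectory() as tmp_dir:
  ckpt = os.path.join(tmp_dir, "epoch.250")
  # When the checkpoint was trained, output_prob only had the readout input (500 dims).
  save_split_info(ckpt, {"output/rec/output_prob/W": {"input_parts": [["readout", 500]]}})
  stored = load_split_info(ckpt)
  # The extended layer gets 'from': ['lm_output_prob', 'readout'].
  old_split = old_input_split(stored["output/rec/output_prob/W"]["input_parts"],
                              [("lm_output_prob", 10025), ("readout", 500)])
print(old_split)  # [0, 500]
```

Matching by source name is itself a design choice: renaming a layer between trainings would need extra handling, but the behavior stays predictable.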