Open Zettelkasten opened 2 years ago
I have a config like this:
```python
'source_embed_raw': {
    'activation': None, 'class': 'linear',
    'forward_weights_init': "variance_scaling_initializer(mode='fan_in', distribution='uniform', scale=1.0)",
    'n_out': 512, 'with_bias': False, 'from': 'data:data'},
# ...
'output': {'class': 'rec',
    # ...
    'unit': {
        # ...
        'output_prob': {
            'class': 'softmax', 'from': ['decoder'],
            'reuse_params': {'map': {
                'W': {
                    'custom': (lambda reuse_layer, **kwargs: tf.transpose(reuse_layer.params["W"])),
                    'reuse_layer': 'target_embed_raw'},
                'b': None}},
            'target': 'classes', 'with_bias': True},
        'target_embed_raw': {
            'activation': None, 'class': 'linear', 'from': ['prev:output'], 'n_out': 512,
            'reuse_params': {'map': {
                'W': {'reuse_layer': 'base:source_embed_raw'},
                'b': None}},
            'with_bias': False},
        # ...
    }}
```
where `target_embed_raw` shares from `base:source_embed_raw`, and `output_prob` shares from `target_embed_raw` (via `custom`).
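Conceptually, this is the usual embedding/softmax weight tying: one layer owns `W`, a second layer reuses the same `W`, and a third uses its transpose as the output projection. A minimal NumPy sketch of the idea, entirely independent of RETURNN (the function names here are illustrative, not RETURNN API):

```python
import numpy as np

# Hypothetical standalone sketch of the sharing scheme described above:
# one layer owns W (n_vocab x n_emb); the embedding lookup reuses that
# same W; the softmax projection uses W transposed (shape n_emb x n_vocab),
# mirroring the tf.transpose in the 'custom' function of the config.
rng = np.random.default_rng(0)
n_vocab, n_emb = 3, 512

W = rng.normal(size=(n_vocab, n_emb))  # the single shared parameter

def target_embed(token_ids):
    # corresponds to target_embed_raw reusing W: row lookup, no copy
    return W[token_ids]

def output_logits(decoder_out):
    # corresponds to output_prob's custom reuse: use W transposed
    return decoder_out @ W.T  # (batch, n_emb) @ (n_emb, n_vocab)

emb = target_embed(np.array([0, 2]))  # shape (2, 512)
logits = output_logits(emb)           # shape (2, 3)
```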
Before commit 2a1b840319b59f890958ccc8cad9639e3fea5efd, this worked fine. But that commit broke this. On that revision, I get the following error:
```
File "/home/frithjof/Documents/Lsh/returnn/returnn/tf/layers/base.py", line 1985, in ReuseParams.variable_custom_getter
  line: return self.custom_func(
          base_layer=base_layer, reuse_layer=self.reuse_layer, full_name=name,
          name=param_name, shape=shape, dtype=dtype, getter=getter, **kwargs)
  locals:
    self = <local> <ReuseParams reuse_layer <LinearLayer output/'target_embed_raw' out_type=Data{[T|'time:var:extern_data:classes'[B],B,F|F'target_embed_raw:feature-dense'(512)]}>, map None>
    self.custom_func = <local> <function test_reuse_params_map_custom_transitive_dependency.<locals>.<lambda> at 0x7fd7fb721700>
    base_layer = <local> <SoftmaxLayer output/'output_prob' out_type=Data{[T|'time:var:extern_data:classes'[B],B,F|F'output_prob:feature-dense'(3)]}>
    reuse_layer = <not found>
    self.reuse_layer = <local> <LinearLayer output/'target_embed_raw' out_type=Data{[T|'time:var:extern_data:classes'[B],B,F|F'target_embed_raw:feature-dense'(512)]}>
    full_name = <not found>
    name = <local> 'output/rec/output_prob/W', len = 24
    param_name = <local> 'W'
    shape = <local> (512, 3)
    dtype = <local> tf.float32
    getter = <local> <function _VariableStore.get_variable.<locals>._true_getter at 0x7fd7cc7c7430>
    kwargs = <local> {'initializer': <tensorflow.python.ops.init_ops.VarianceScaling object at 0x7fd7b4313e20>, 'regularizer': None, 'reuse': <_ReuseMode.AUTO_REUSE: 1>, 'trainable': True, 'collections': None, 'caching_device': None, 'partitioner': None, 'validate_shape': True, 'use_resource': None, 'synchronization'..., len = 12
File "/home/frithjof/Documents/Lsh/returnn/tests/test_TFNetworkLayer.py", line 3243, in test_reuse_params_map_custom_transitive_dependency.<locals>.<lambda>
  line: 'reuse_params': {
          'map': {'W': {'custom': (lambda reuse_layer, **kwargs: tf.transpose(reuse_layer.params["W"])),
                        'reuse_layer': 'target_embed_raw'},
                  'b': None}},
  locals:
    reuse_layer = <local> <LinearLayer output/'target_embed_raw' out_type=Data{[T|'time:var:extern_data:classes'[B],B,F|F'target_embed_raw:feature-dense'(512)]}>
    kwargs = <local> {'base_layer': <SoftmaxLayer output/'output_prob' out_type=Data{[T|'time:var:extern_data:classes'[B],B,F|F'output_prob:feature-dense'(3)]}>, 'full_name': 'output/rec/output_prob/W', 'name': 'W', 'shape': (512, 3), 'dtype': tf.float32, 'getter': <function _VariableStore.get_variable.<locals>._true..., len = 18
    tf = <global> <module 'tensorflow' from '/home/frithjof/.local/lib/python3.8/site-packages/tensorflow/__init__.py'>
    tf.transpose = <global> <function transpose_v2 at 0x7fd7fa1e3ca0>
    reuse_layer.params = <local> {}
KeyError: 'W'
```
Note that sharing params is probably easier with the new `name_scope` logic.
But I'm not arguing that this should not be fixed.
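For reference, a `name_scope`-based variant of the same sharing might look roughly like the fragment below. This is an untested sketch: the issue only mentions that the `name_scope` logic exists, so the exact key semantics (in particular the absolute-path form `'/source_embed_raw'`) are an assumption here, not confirmed API.

```python
# Hypothetical sketch (assumed syntax, not verified against RETURNN):
# point target_embed_raw's variable scope at source_embed_raw's, so both
# layers resolve to the same underlying W variable instead of using
# reuse_params.
network_fragment = {
    'target_embed_raw': {
        'class': 'linear', 'activation': None, 'with_bias': False,
        'n_out': 512, 'from': ['prev:output'],
        'name_scope': '/source_embed_raw',  # assumed absolute-scope form
    },
}
```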