Open YannPourcenoux opened 2 years ago
System information
Describe the bug
I get the following error when training with model.fit() and using Yogi extended with weight decay as the optimizer:
model.fit()
`tfa.optimizers.extend_with_decoupled_weight_decay(tfa.optimizers.Yogi)(weight_decay=config.initial_weight_decay, learning_rate=learning_rate_schedule)`
history = model.fit( File "/home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/wandb/integration/keras/keras.py", line 163, in new_v2 return old_v2(*args, **kwargs) File "/home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/keras/engine/training.py", line 1178, in fit tmp_logs = self.train_function(iterator) File "/home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py", line 889, in __call__ result = self._call(*args, **kwds) File "/home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py", line 933, in _call self._initialize(args, kwds, add_initializers_to=initializers) File "/home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py", line 763, in _initialize self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access File "/home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 3050, in _get_concrete_function_internal_garbage_collected graph_function, _ = self._maybe_define_function(args, kwargs) File "/home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 3444, in _maybe_define_function graph_function = self._create_graph_function(args, kwargs) File "/home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 3279, in _create_graph_function func_graph_module.func_graph_from_py_func( File "/home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 999, in func_graph_from_py_func func_outputs = python_func(*func_args, **func_kwargs) File "/home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py", line 672, in wrapped_fn out = weak_wrapped_fn().__wrapped__(*args, **kwds) File 
"/home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 986, in wrapper raise e.ag_error_metadata.to_exception(e) TypeError: in user code: /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/keras/engine/training.py:850 train_function * return step_function(self, iterator) /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/keras/engine/training.py:840 step_function ** outputs = model.distribute_strategy.run(run_step, args=(data,)) /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:1285 run return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs) /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:2833 call_for_each_replica return self._call_for_each_replica(fn, args, kwargs) /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:3608 _call_for_each_replica return fn(*args, **kwargs) /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/keras/engine/training.py:833 run_step ** outputs = model.train_step(data) /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/keras/engine/training.py:794 train_step self.optimizer.minimize(loss, self.trainable_variables, tape=tape) /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow_addons/optimizers/weight_decay_optimizers.py:161 minimize return super().minimize( /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:530 minimize return self.apply_gradients(grads_and_vars, name=name) /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow_addons/optimizers/weight_decay_optimizers.py:189 apply_gradients return super().apply_gradients(grads_and_vars, name=name, **kwargs) 
/home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:667 apply_gradients return self._distributed_apply(strategy, grads_and_vars, name, /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:714 _distributed_apply update_op = distribution.extended.update( /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:2578 update return self._replica_ctx_update( /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:2457 _replica_ctx_update return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs) /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:3048 merge_call return self._merge_call(merge_fn, args, kwargs) /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:3055 _merge_call return merge_fn(self._strategy, *args, **kwargs) /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:2455 merge_fn ** return self.update(var, fn, merged_args, merged_kwargs, group=group) /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:2576 update return self._update(var, fn, args, kwargs, group) /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:3622 _update return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group) /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:3628 _update_non_slot result = fn(*args, **kwargs) /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:697 apply_grad_to_update_var ** update_op = self._resource_apply_dense(grad, 
var, **apply_kwargs) /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow_addons/optimizers/weight_decay_optimizers.py:236 _resource_apply_dense return super()._resource_apply_dense(grad, var, apply_state=apply_state) TypeError: _resource_apply_dense() got an unexpected keyword argument 'apply_state'
Yes, currently it is not handling the apply_state. If you want, you could prepare a PR and extend the Yogi test.
apply_state
System information
Describe the bug
I get the following error when training using
model.fit()
when using Yogi extended with weight decay as optimizer: