tensorflow / addons

Useful extra functionality for TensorFlow 2.x maintained by SIG-addons
Apache License 2.0

Training fails when using Yogi extended with weight decay as optimizer #2716

Open YannPourcenoux opened 2 years ago

YannPourcenoux commented 2 years ago

System information

Describe the bug

Training with model.fit() fails with the following error when the optimizer is Yogi extended with decoupled weight decay:

tfa.optimizers.extend_with_decoupled_weight_decay(tfa.optimizers.Yogi)(
    weight_decay=config.initial_weight_decay,
    learning_rate=learning_rate_schedule)
    history = model.fit(
  File "/home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/wandb/integration/keras/keras.py", line 163, in new_v2
    return old_v2(*args, **kwargs)
  File "/home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/keras/engine/training.py", line 1178, in fit
    tmp_logs = self.train_function(iterator)
  File "/home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py", line 889, in __call__
    result = self._call(*args, **kwds)
  File "/home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py", line 933, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py", line 763, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 3050, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "/home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 3444, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 3279, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 999, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py", line 672, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 986, in wrapper
    raise e.ag_error_metadata.to_exception(e)
TypeError: in user code:

    /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/keras/engine/training.py:850 train_function  *
        return step_function(self, iterator)
    /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/keras/engine/training.py:840 step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:1285 run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:2833 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:3608 _call_for_each_replica
        return fn(*args, **kwargs)
    /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/keras/engine/training.py:833 run_step  **
        outputs = model.train_step(data)
    /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/keras/engine/training.py:794 train_step
        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow_addons/optimizers/weight_decay_optimizers.py:161 minimize
        return super().minimize(
    /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:530 minimize
        return self.apply_gradients(grads_and_vars, name=name)
    /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow_addons/optimizers/weight_decay_optimizers.py:189 apply_gradients
        return super().apply_gradients(grads_and_vars, name=name, **kwargs)
    /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:667 apply_gradients
        return self._distributed_apply(strategy, grads_and_vars, name,
    /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:714 _distributed_apply
        update_op = distribution.extended.update(
    /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:2578 update
        return self._replica_ctx_update(
    /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:2457 _replica_ctx_update
        return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs)
    /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:3048 merge_call
        return self._merge_call(merge_fn, args, kwargs)
    /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:3055 _merge_call
        return merge_fn(self._strategy, *args, **kwargs)
    /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:2455 merge_fn  **
        return self.update(var, fn, merged_args, merged_kwargs, group=group)
    /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:2576 update
        return self._update(var, fn, args, kwargs, group)
    /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:3622 _update
        return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)
    /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:3628 _update_non_slot
        result = fn(*args, **kwargs)
    /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:697 apply_grad_to_update_var  **
        update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
    /home/yann/anaconda3/envs/uvr-dev/lib/python3.9/site-packages/tensorflow_addons/optimizers/weight_decay_optimizers.py:236 _resource_apply_dense
        return super()._resource_apply_dense(grad, var, apply_state=apply_state)

    TypeError: _resource_apply_dense() got an unexpected keyword argument 'apply_state'

bhack commented 2 years ago

Yes, currently Yogi does not handle apply_state. If you want, you could prepare a PR and extend the Yogi test.
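
For readers following the traceback, here is a minimal pure-Python sketch of what appears to go wrong. It is not the tensorflow_addons code: the class names `BaseWithoutApplyState` and `BaseWithApplyState`, the helper `extend_with_decay`, and the fixed 0.1 step size are hypothetical stand-ins. It assumes only what the traceback shows, namely that the weight-decay extension forwards `apply_state` to the base optimizer's `_resource_apply_dense`, which then has to accept that keyword.

```python
class BaseWithoutApplyState:
    """Hypothetical stand-in for an optimizer whose update method lacks `apply_state`."""

    def _resource_apply_dense(self, grad, var):
        return var - 0.1 * grad  # toy gradient step, not Yogi's real update


class BaseWithApplyState:
    """Same toy optimizer, but its update method accepts `apply_state`."""

    def _resource_apply_dense(self, grad, var, apply_state=None):
        return var - 0.1 * grad


def extend_with_decay(base_cls):
    """Hypothetical, simplified analogue of a decoupled weight decay extension."""

    class Extended(base_cls):
        def __init__(self, weight_decay):
            self.weight_decay = weight_decay

        def _resource_apply_dense(self, grad, var, apply_state=None):
            decayed = var - self.weight_decay * var  # decay the weight first
            # The keyword is forwarded unconditionally, as in the traceback.
            return super()._resource_apply_dense(
                grad, decayed, apply_state=apply_state
            )

    return Extended


# The base class without the keyword reproduces the TypeError from the report.
try:
    extend_with_decay(BaseWithoutApplyState)(0.01)._resource_apply_dense(1.0, 2.0)
except TypeError as err:
    print("fails:", err)

# Once the base class accepts the keyword, the extended update goes through.
print("works:", extend_with_decay(BaseWithApplyState)(0.5)._resource_apply_dense(1.0, 2.0))
```

Under that assumption, a fix would presumably add an `apply_state=None` parameter to Yogi's update methods, together with a test that covers the extended optimizer.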