rwth-i6 / returnn

The RWTH extensible training framework for universal recurrent neural networks
http://returnn.readthedocs.io/

NAdam different between TF1.15 and TF2.3 #766

Closed · JackTemaki closed this 2 years ago

JackTemaki commented 2 years ago

I am not posting the full information here yet, because I still need to collect some more details; this is more of a starting point for the issue:

Summary of important facts:

What happens:

I did many repetitions of this experiment with different learning rates and different batch sizes, and also tried step-based warmup, but it always resulted in the same behavior.

TF1: [TF1_grads plot] [TF1_loss plot]

TF2: [TF2_grads plot] [TF2_loss plot]

albertz commented 2 years ago

The gradient norms you mean?

When you make sure to start from the same model (e.g. via task="initialize_model"), you could verify whether it is really exactly the same or not (at least in the first batch).

JackTemaki commented 2 years ago

> The gradient norms you mean?

If the norm is different, the gradient is different. I just displayed the norm here; of course I looked at all the values.

> When you make sure to start from the same model (e.g. via task="initialize_model"), you could verify whether it is really exactly the same or not (at least in the first batch).

Ah sorry, I did not mention this. Yes, this is the case: I start from the same initialized model. The reason was mainly to exclude the possibility that the initialization itself changed between TF1.15 and TF2.3 for some reason.
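
For reference, starting both runs from the same initial checkpoint could look roughly like the following RETURNN config sketch; the paths are placeholders, and the key for loading that checkpoint in epoch 1 (`import_model_train_epoch1`) is given from memory and should be treated as an assumption.

```python
# One-off run: write an initial checkpoint once (sketch; paths are placeholders).
task = "initialize_model"
model = "/path/to/init-model"  # checkpoint prefix written by this run

# Then, in both the TF1.15 and the TF2.3 training configs (sketch):
# task = "train"
# import_model_train_epoch1 = "/path/to/init-model"
```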

albertz commented 2 years ago

But is the gradient still the same (within some threshold) for the first mini batch? I.e. the difference accumulates slowly over time? Or is there some specific later step where it is clearly different, and from there on it becomes different? Or is it already different in the first mini batch? And where exactly is it different? Already the error signal from the loss? Or at what layer does it become different in backprop?

JackTemaki commented 2 years ago

> But is the gradient still the same (within some threshold) for the first mini batch? I.e. the difference accumulates slowly over time? Or is there some specific later step where it is clearly different, and from there on it becomes different? Or is it already different in the first mini batch? And where exactly is it different? Already the error signal from the loss? Or at what layer does it become different in backprop?

The training logs show that the error signal is identical (first minibatch shows the same loss, so before the first gradient it is identical), and for the second minibatch there is already a noticeable difference.

I will try to come up with a toy task (artificial data, small network) that reproduces this.

For the last question I would need to produce shorter logs; you cannot see the raw values properly in TensorBoard as it is now, so I cannot tell whether for the first step the gradient differs only in e.g. the last LSTM layer but not in the linear transformation before the softmax.

albertz commented 2 years ago

The same loss does not mean that the gradients are identical. Did you directly check the gradients (not just the norm)?

albertz commented 2 years ago

I would not just look at TensorBoard. I would write a small script which dumps the gradients to some file, maybe after each mini batch, maybe also other information, and then investigate that by hand.
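
For example, assuming the gradients of both setups are dumped as one `.npz` file per step (the file layout and variable names here are assumptions), a comparison script could look like this:

```python
# Minimal comparison sketch: load the per-step gradient dumps of both TF versions
# and report where they differ (assumed file layout, not an actual RETURNN tool).
import numpy as np

tf1 = np.load("grads_tf115/step_000.npz")
tf2 = np.load("grads_tf23/step_000.npz")
for name in sorted(set(tf1.files) & set(tf2.files)):
    a, b = tf1[name], tf2[name]
    same_shape = a.shape == b.shape
    max_abs = np.max(np.abs(a - b)) if same_shape else float("inf")
    status = "OK" if same_shape and np.allclose(a, b, rtol=1e-5, atol=1e-6) else "DIFFERS"
    print(f"{name}: max|diff|={max_abs:.3e} {status}")
```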

JackTemaki commented 2 years ago

> The same loss does not mean that the gradients are identical. Did you directly check the gradients (not just the norm)?

It does not, and I did not say that. I said that the first entry (the first minibatch, before any update) shows an identical loss, which means that the loss computation itself is not flawed.

The second entry shows a diverging loss, so the first update was already different. To see where exactly it differs, I would indeed need to dump the gradients.

albertz commented 2 years ago

How much different is the loss in the second step? Or how much do the gradients really differ in the first step?

I would assume that the first gradients are still all very similar (up to some threshold) and the differences just accumulate over time.

This is then maybe not even a bug, just different behavior in different TF versions, and you are unlucky here that this different behavior leads to convergence in one case but not in the other.

Although, as this is now a problem for you, you need to investigate what exactly is different in the behavior and understand it better. Maybe replicate the old behavior somehow, or find other ways.

albertz commented 2 years ago

If the first gradients are already different, that's good, because that means it will be easy to debug. Then you should check at what point it becomes different: already at the loss gradient, or at some layer?

JackTemaki commented 2 years ago

> E.g. already at the loss gradient, or at some layer?

I did not extract the gradients yet, but I opened the raw values of the summary file. There you can clearly see that in the first update step the gradients for the softmax bias, which directly depend on the error, are already off.

All tags are under `grads/optimize/gradients/output/linear/add_bias_grad/tuple/summaries_grad_of_b/`:

| statistic | TF 2.3 | TF 1.15 |
| --- | --- | --- |
| grad_of_b_mean | -0.0031194891780614853 | -0.0035799972247332335 |
| grad_of_b_stddev | 107.59998321533203 | 106.58272552490234 |
| grad_of_b_rms | 107.59998321533203 | 106.58272552490234 |
| grad_of_b_l2 | 634.2925415039062 | 628.2958984375 |
| grad_of_b_max | 22.030406951904297 | 22.547555923461914 |
| grad_of_b_min | -1243.468017578125 | -1227.8902587890625 |
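
For reference, a minimal sketch for reading such raw scalar values directly from the TensorBoard event file (the file path is a placeholder):

```python
# Iterate over the event file and print the raw grad_of_b summaries per step.
import tensorflow as tf

path = "events.out.tfevents.XXXX"  # placeholder
for event in tf.compat.v1.train.summary_iterator(path):
    for value in event.summary.value:
        if "grad_of_b" in value.tag:
            print(event.step, value.tag, value.simple_value)
```
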
JackTemaki commented 2 years ago

I assume the gradient statistics (the Adam moment accumulators) are initialized differently in TF1.15 and TF2.3.

EDIT: I already checked the default parameters; they did not change, except for the epsilon, which is set explicitly in the RETURNN config anyway.

albertz commented 2 years ago

Ah interesting.

Can you disable Adam, to check whether Adam or some Adam-specific things are causing this? Just use standard SGD.

I'm not sure if this is maybe still within numerical fluctuations...

You can easily directly check the gradient of the loss then (w.r.t. the logits).

Maybe TF accumulates them differently between TF 1.15 and TF 2.3. For the bias, it will accumulate the gradients over the time frames (hopefully correctly discarding the padded frames as well).
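
A minimal sketch (not RETURNN code; the toy shapes, data, and mask handling are assumptions) of how the loss gradient w.r.t. the logits can be inspected directly under TF 2.x, with padded frames masked out:

```python
import tensorflow as tf

logits = tf.random.normal([2, 5, 10])                           # (batch, time, classes)
targets = tf.random.uniform([2, 5], maxval=10, dtype=tf.int32)  # (batch, time)
mask = tf.sequence_mask([5, 3], maxlen=5, dtype=logits.dtype)   # zero out padded frames

with tf.GradientTape() as tape:
    tape.watch(logits)
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets, logits=logits)
    loss = tf.reduce_sum(ce * mask)

# Per-frame error signal: softmax(logits) - one_hot(targets) on unmasked frames, 0 elsewhere.
grad_logits = tape.gradient(loss, logits)
# The gradient of the output-layer bias is this error signal summed over batch and time.
grad_bias = tf.reduce_sum(grad_logits, axis=[0, 1])
print(grad_bias.numpy())
```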

JackTemaki commented 2 years ago

Okay, I did something wrong while managing my debug configs. The cause is most likely that the NAdam implementation changed between TF1.15 and TF2.3.

JackTemaki commented 2 years ago

So yes, the model also converges nicely with TF2.3 when switching to Adam. I already saw the NAdam discrepancy yesterday, but I thought I had switched to Adam to exclude this mismatch.

The cause is the following implementation difference. Native CUDA kernel in TF1.15 (`tensorflow/tensorflow/core/kernels/training_ops_gpu.cu.cc`):

```cpp
template <typename T>
__global__ void ApplyAdamKernel(int32 data_dim, T* var, T* m, T* v,
                                const T* const beta1_power_,
                                const T* const beta2_power_, const T* const lr_,
                                const T* const beta1_, const T* const beta2_,
                                const T* const epsilon_, const T* grad,
                                bool use_nesterov) {
  eigen_assert(blockDim.y == 1);
  eigen_assert(blockDim.z == 1);
  eigen_assert(gridDim.y == 1);
  eigen_assert(gridDim.z == 1);

  const T mul_factor = (*lr_) * sqrt(static_cast<T>(1.0) - (*beta2_power_)) /
                       (static_cast<T>(1.0) - (*beta1_power_));
  const T epsilon = (*epsilon_);
  const T beta1 = (*beta1_);
  const T one_minus_beta1 = static_cast<T>(1.0) - (beta1);
  const T one_minus_beta2 = static_cast<T>(1.0) - (*beta2_);
  const int32 stripe = gridDim.x * blockDim.x;

  for (int32 i = blockIdx.x * blockDim.x + threadIdx.x; i < data_dim;
       i += stripe) {
    auto m_i = m[i];
    auto g_i = grad[i];
    auto v_i = v[i];

    m_i += one_minus_beta1 * (g_i - m_i);
    v_i += one_minus_beta2 * (g_i * g_i - v_i);
    if (use_nesterov) {
      var[i] -= mul_factor * (m_i * beta1 + one_minus_beta1 * g_i) /
                (epsilon + sqrt(v_i));
    } else {
      var[i] -= mul_factor * m_i / (epsilon + sqrt(v_i));
    }

    m[i] = m_i;
    v[i] = v_i;
  }
}
```

Code in TF2.3 / Keras (`keras/optimizer_v2/nadam.py`):

```python
  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    m = self.get_slot(var, 'm')
    v = self.get_slot(var, 'v')

    g_prime = grad / coefficients['one_minus_m_schedule_new']
    m_t = (coefficients['beta_1_t'] * m +
           coefficients['one_minus_beta_1_t'] * grad)
    m_t = tf.compat.v1.assign(m, m_t, use_locking=self._use_locking)
    m_t_prime = m_t / coefficients['one_minus_m_schedule_next']
    v_t = (coefficients['beta_2_t'] * v +
           coefficients['one_minus_beta_2_t'] * tf.square(grad))
    v_t = tf.compat.v1.assign(v, v_t, use_locking=self._use_locking)
    v_t_prime = v_t / coefficients['v_t_prime_denominator']
    m_t_bar = (coefficients['one_minus_m_t'] * g_prime +
               coefficients['m_t_1'] * m_t_prime)
    var_t = var - coefficients['lr_t'] * m_t_bar / (
        tf.sqrt(v_t_prime) + coefficients['epsilon'])
    return tf.compat.v1.assign(var, var_t, use_locking=self._use_locking).op
```

The differences are:
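
Judging from the two snippets, the most visible difference is that the Keras version additionally applies the NAdam momentum schedule (the `one_minus_m_schedule_new` / `one_minus_m_schedule_next` coefficients), while the TF1.15 kernel only applies the plain Nesterov correction together with the usual Adam bias correction. For reference, a NumPy translation of the TF1.15 kernel above (single step; the default hyperparameters here are assumptions), which can be compared element-wise against what the Keras optimizer produces:

```python
import numpy as np

def tf115_nesterov_adam_step(var, m, v, grad, step, lr=1e-3,
                             beta1=0.9, beta2=0.999, epsilon=1e-8):
    """One update step, translated from ApplyAdamKernel above with use_nesterov=True."""
    beta1_power = beta1 ** step
    beta2_power = beta2 ** step
    mul_factor = lr * np.sqrt(1.0 - beta2_power) / (1.0 - beta1_power)
    m = beta1 * m + (1.0 - beta1) * grad       # m_i += (1 - beta1) * (g_i - m_i)
    v = beta2 * v + (1.0 - beta2) * grad ** 2  # v_i += (1 - beta2) * (g_i * g_i - v_i)
    var = var - mul_factor * (m * beta1 + (1.0 - beta1) * grad) / (epsilon + np.sqrt(v))
    return var, m, v
```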

albertz commented 2 years ago

I reported the problem for TF here: https://github.com/tensorflow/tensorflow/issues/53204