flatironinstitute / nomad

Non-linear Matrix Decomposition library
Apache License 2.0

[Feature Request] Add aggressive momentum NMD kernel from Seraghiti et al. (2023) #18

Open sfohr opened 2 months ago

sfohr commented 2 months ago

Which of these best describes your feature request:

Describe how the new feature would improve the library: I suggest the addition of the Aggressive Momentum NMD (A-NMD) kernel from Seraghiti et al. (2023). Similar to their Momentum 3-Block NMD, it is an extension of the base model-free algorithm, which alternates between constructing the utility matrix $Z$ and the low-rank approximation $\Theta$ (subsequently called $L$). In contrast to the base model-free kernel, the authors add:

Leaving aside most convergence criteria, their matlab implementation uses the following input parameters:

With hyperparameter constraints:

And returns the low-rank matrix $L$.
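For context, the base model-free iteration (which A-NMD extends) alternates roughly as follows; this is a paraphrase in terms of construct_utility and find_low_rank, not the paper's exact notation:

$$Z^{k}_{ij} = \begin{cases} X_{ij} & \text{if } X_{ij} > 0 \\ \min\left(0, L^{k}_{ij}\right) & \text{otherwise} \end{cases} \qquad L^{k+1} = \operatorname{TSVD}_r\left(Z^{k}\right)$$

where $\operatorname{TSVD}_r$ stands for a rank-$r$ truncated SVD (or whichever svd_strategy is configured).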

The extrapolation of $Z$ and $L$ is equivalent to the extrapolation done in Momentum 3-Block NMD, with additional heuristic tuning of $\beta$ and an accept/reject mechanism on $Z$ and $L$, which is described in the following section.
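Concretely, the extrapolation performed by apply_momentum amounts to something like (with $\beta$ the momentum parameter):

$$Z^{k} \leftarrow Z^{k} + \beta \left(Z^{k} - Z^{k-1}\right), \qquad L^{k+1} \leftarrow L^{k+1} + \beta \left(L^{k+1} - L^{k}\right)$$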

Tuning Scheme:

Describe the solution you'd like: My idea for the step method revolves around changing the order of operations to avoid the parameter maxiter, by moving the momentum step on $L$ and the parameter tuning to the beginning of each step (similar to #13):

def step(self) -> None:
  if self.elapsed_iterations > 0:
      # momentum step on L
      self.low_rank_candidate_L = apply_momentum(
          self.low_rank_candidate_L,
          self.previous_low_rank_candidate_L,
          self.momentum_beta,
      )
      if self.elapsed_iterations > 2:
          # parameter tuning and accept/reject steps
          if self.loss_is_decreasing():
              self.increase_momentum_parameters()
              self.accept_matrix_updates()
          else:
              self.decrease_momentum_parameters()
              self.reject_matrix_updates()

  # construct utility matrix Z from the current candidate L and the sparse data X
  utility_matrix_Z = construct_utility(
      self.low_rank_candidate_L, self.sparse_matrix_X
  )

  # momentum step on Z (extrapolate against the previously stored Z)
  self.utility_matrix_Z = apply_momentum(
      utility_matrix_Z, self.utility_matrix_Z, self.momentum_beta
  )

  # enforce the target rank to obtain the new low-rank candidate L
  self.low_rank_candidate_L = find_low_rank(
      self.utility_matrix_Z,
      self.target_rank,
      self.low_rank_candidate_L,
      self.svd_strategy,
  )

  if self.tolerance is not None:
      self.loss = compute_loss(
          self.utility_matrix_Z, self.low_rank_candidate_L, LossType.FROBENIUS
      )

Apart from omitting maxiter, the different order of operations ensures that, after each step of the algorithm, the kernel's state actually holds the objects that produce the current loss, thereby ensuring the integrity of the information returned by per_iteration_diagnostic.

Utility functions Functions construct_utility, apply_momentum, find_low_rank and compute_loss can be reused from previous implementations.

accept_matrix_updates and reject_matrix_updates are simple class methods that copy the current matrices from the slots self.low_rank_candidate_L and self.utility_matrix_Z to self.previous_low_rank_candidate_L and self.previous_utility_matrix_Z (accept case), or the other way round (reject case).
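A minimal sketch of that pair, assuming both slots are preallocated numpy arrays of matching shape:

import numpy as np

def accept_matrix_updates(self) -> None:
    # keep the current Z and L as the reference point for the next momentum step
    np.copyto(self.previous_utility_matrix_Z, self.utility_matrix_Z)
    np.copyto(self.previous_low_rank_candidate_L, self.low_rank_candidate_L)

def reject_matrix_updates(self) -> None:
    # roll Z and L back to the last accepted versions
    np.copyto(self.utility_matrix_Z, self.previous_utility_matrix_Z)
    np.copyto(self.low_rank_candidate_L, self.previous_low_rank_candidate_L)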

self.increase_momentum_parameters and self.decrease_momentum_parameters handle the tuning of the momentum parameter $\beta$ and its upper bound $\bar{\beta}$.
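One plausible implementation, following the adaptation scheme from the paper's pseudocode; the attribute names momentum_gamma, momentum_gamma_bar, momentum_eta, momentum_beta_bar and last_good_momentum_beta are purely illustrative, and the exact reset of the upper bound is discussed further down in the thread:

def increase_momentum_parameters(self) -> None:
    # the current beta produced a loss decrease: remember it, then grow beta
    # (capped by beta_bar) and relax the cap itself (assumed capped at 1)
    self.last_good_momentum_beta = self.momentum_beta
    self.momentum_beta = min(self.momentum_gamma * self.momentum_beta, self.momentum_beta_bar)
    self.momentum_beta_bar = min(self.momentum_gamma_bar * self.momentum_beta_bar, 1.0)

def decrease_momentum_parameters(self) -> None:
    # the current beta was too aggressive: shrink it (0 < eta < 1 assumed) and
    # pull the cap back to the last beta that still decreased the loss
    self.momentum_beta = self.momentum_eta * self.momentum_beta
    self.momentum_beta_bar = self.last_good_momentum_beta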

self.loss_is_decreasing first backs up the last loss we used for parameter tuning, then calculates the current one and returns True if the loss is decreasing and False otherwise.
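Roughly, and assuming the parameter-update loss is the same Frobenius loss evaluated on the post-momentum candidates (which loss to use here is picked up again further down in the thread):

def loss_is_decreasing(self) -> bool:
    # back up the loss used for the previous tuning decision, then recompute it
    # for the current (post-momentum) candidates and compare
    previous_parameter_update_loss = self.parameter_update_loss
    self.parameter_update_loss = compute_loss(
        self.utility_matrix_Z, self.low_rank_candidate_L, LossType.FROBENIUS
    )
    return self.parameter_update_loss < previous_parameter_update_loss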

kernelInputTypes Algorithm-specific parameters:
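As a rough illustration only (the concrete parameter list is omitted above; the class and field names here are hypothetical, not a claim about the library's existing input types):

from dataclasses import dataclass

@dataclass
class AggressiveMomentumAdditionalParameters:  # hypothetical name
    momentum_beta: float        # initial momentum parameter beta
    momentum_beta_bar: float    # upper bound beta-bar on beta
    momentum_gamma: float       # growth factor for beta (assumed > 1)
    momentum_gamma_bar: float   # growth factor for beta-bar (assumed > 1)
    momentum_eta: float         # shrink factor for beta (assumed in (0, 1))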

kernelReturnTypes Returns the reconstruction, so adds nothing to the base class.

Describe alternatives you've considered: I considered the regular order of operations, as described in the journal article:

def step(self) -> None:
  utility_matrix_Z = construct_utility(
      self.low_rank_candidate_L, self.sparse_matrix_X
  )

  self.utility_matrix_Z = apply_momentum(
      utility_matrix_Z, self.utility_matrix_Z, self.momentum_beta
  )

  self.low_rank_candidate_L = find_low_rank(
      self.utility_matrix_Z,
      self.target_rank,
      self.low_rank_candidate_L,
      self.svd_strategy,
  )

  if self.tolerance is not None:
      self.loss = compute_loss(
          self.utility_matrix_Z, self.low_rank_candidate_L, LossType.FROBENIUS
      )
  # momentum step on L (skipped on the final iteration, so this variant needs maxiter)
  if self.elapsed_iterations < (self.maxiter - 1):
    self.low_rank_candidate_L = apply_momentum(
        self.low_rank_candidate_L,
        self.previous_low_rank_candidate_L,
        self.momentum_beta,
    )
    if self.elapsed_iterations > 1:
        # parameter tuning and accept/reject steps
        if self.loss_is_decreasing():
            self.increase_momentum_parameters()
            self.accept_matrix_updates()
        else:
            self.decrease_momentum_parameters()
            self.reject_matrix_updates()

References

jsoules commented 2 months ago

This seems like a solid approach. I like your idea of reordering the operations--to be clear, in your proposed solution, you apply the momentum to L (and potentially update the momentum parameters) as the first part of every step, while in the paper, the algorithm does this at the end of every step. Sounds great, and I've worked through and convinced myself that the code for your proposed solution does the same thing as the code for your rejected solution, just in a more convenient way.

However, I wanted to discuss some discrepancies I noticed in the source work you're implementing. Specifically, I'm comparing the pseudocode on p3 of the paper ("Algorithm 1") with the Matlab A-NMD implementation (https://gitlab.com/ngillis/ReLU-NMD/-/blob/main/A_NMD.m?ref_type=heads), and they don't seem to agree.

In the pseudocode:

In the matlab, however, it's as you've described:

So this is a little tricky. I think you are right to follow the Matlab implementation rather than the pseudocode, but I wanted to document the discrepancy and the choice.

Also, the Matlab implementation is using two definitions of "loss": when doing early stopping, it looks at the loss of the (rank-enforced) reconstruction; when deciding whether to update the momentum, it looks at the loss of the (post-momentum) estimate. It isn't clear to me that this matters, but I wanted to point it out.

Another discrepancy between Matlab code and pseudocode is how they reset the maximum momentum (beta-bar). In the pseudocode:

- when increasing momentum, the new beta-bar is set to gamma-bar * (the current value of) beta-bar
- when decreasing momentum, beta-bar is set to the previous value of beta (i.e. from one iteration prior)

In the Matlab code:

- when decreasing momentum, beta-bar is set to the value of beta from two iterations prior
It kind of makes sense to set the max-beta to the one from 2 iterations ago, rather than last iteration, because we are assuming that the loss increased because we applied too much momentum. (= beta last time was too big.) Then again, this is updating the upper bound for beta, so maybe it's okay if the upper bound is something too big? And maybe the loss increased because of the low-rank reconstruction step, rather than using too much momentum?

Anyway, I suspect that these details don't make a consistent difference in practice--just choose something sensible, document what you choose, and you'll probably be fine.

Two final things: 1) Your proposal computes the loss only if self.tolerance is not None--which makes sense for other kernels where the loss is only used for early stopping. But I expect for this algorithm you need to compute the loss regardless of whether a tolerance was set.

2) I'm not really sure why the algorithm tracks the history of Z; doesn't Z effectively get overwritten every step when we do min(0, L)?

sfohr commented 2 months ago

Hi Jeff,

Thanks for the writeup; you're right, it makes sense to properly document the discrepancies.

Also, the Matlab implementation is using two definitions of "loss": when doing early stopping, it looks at the loss of the (rank-enforced) reconstruction; when deciding whether to update the momentum, it looks at the loss of the (post-momentum) estimate. It isn't clear to me that this matters, but I wanted to point it out.

The algorithm descriptions in the paper do not state how and when to calculate the "final" loss of the reconstruction $\Theta$ (Aggressive Momentum NMD) or of $W$ and $H$ (3-Block NMD), respectively. For Aggressive Momentum NMD they write: $$\Theta = \Theta^{k+1}$$

So $\Theta^{k+1}$ is our low rank candidate after applying the momentum term, which might result in a matrix whose rank differs from the target rank; I assume this is not the intended way to go (note that the same is true for 3B-NMD, and they explicitly state this issue in the Matlab implementation).

Further, they explicitly write down the post-momentum loss calculation used to determine how to update $\beta$ and $\bar{\beta}$ and to reject/accept $Z^{k+1}$ and $\Theta^{k+1}$ in the algorithm description. The two losses serve different purposes: the former is used to assess the accuracy of the final reconstruction (which needs to have the target rank), the latter to check whether the extrapolation using the momentum terms was too extreme and therefore increased the loss. So, implicitly, an increasing loss is attributed to the momentum terms.

Another discrepancy between Matlab code and pseudocode is how they reset the maximum momentum (beta-bar).
In the pseudocode: when increasing momentum, [new] beta-bar is set to gamma-bar * [the current value of] beta-bar

I'll go with the pseudocode in this case; it's consistent with the paper in which they introduced this adaptation scheme for Nonnegative Matrix Factorization (https://arxiv.org/pdf/1805.06604). I tried both approaches on a simulated matrix and the difference is negligible.

[Pseudocode:] when decreasing momentum, beta-bar is set to the previous value of beta (i.e. from one iteration prior) ... [matlab:] when decreasing momentum, beta-bar is set to the value of beta from two iterations prior ... It kind of makes sense to set the max-beta to the one from 2 iterations ago, rather than last iteration, because we are assuming that the loss increased because we applied too much momentum.

I think this is due to notation. In the pseudocode it's $\bar{\beta}=\beta_{k-1}$. The $\beta$ that induced the increase of the parameter update loss is $\beta_k$, so $\beta_{k-1}$ is the last $\beta$ that decreased the loss. In the Matlab implementation it's beta_bar=beta_history(i-2). In the previous two lines, they first compute the new beta (for the next iteration) with param.beta=param.eta*param.beta;, then append it to beta_history via beta_history(i)=param.beta;, so the $\beta$ that induced the increase of the parameter update loss is beta_history(i-1) ($\beta_k$ in the pseudocode) and the last $\beta$ that decreased the loss is beta_history(i-2) ($\beta_{k-1}$ in the pseudocode). Basically a shift in indices due to notation.

Your proposal computes the loss only if self.tolerance is not None--which makes sense for other kernels where the loss is only used for early stopping. But I expect for this algorithm you need to compute the loss regardless of whether a tolerance was set.

That is not the case: parameter update losses are computed separately. The algorithm can run without computing the final loss, but it needs to compute the parameter update loss to adapt $\beta$. I guess the confusion stems from improper naming of methods. The method self.loss_is_decreasing() first computes the current parameter update loss, checks whether it's lower than self.parameter_update_loss (the last one), then updates self.parameter_update_loss and returns True or False. I also noticed that, in my proposed implementation, I never compute the parameter update loss for iteration 2. So I'll restructure this part into something like:

def step(self) -> None:
  if self.elapsed_iterations > 0:
      # momentum step on L
      self.low_rank_candidate_L = apply_momentum(
          self.low_rank_candidate_L,
          self.previous_low_rank_candidate_L,
          self.momentum_beta,
      )

      previous_parameter_update_loss = self.parameter_update_loss
      self.parameter_update_loss = self.compute_parameter_update_loss()

      if self.elapsed_iterations > 2:
          # parameter tuning and accept/reject steps
          if self.parameter_update_loss < previous_parameter_update_loss:
              self.increase_momentum_parameters()
              self.accept_matrix_updates()
          else:
              self.decrease_momentum_parameters()
              self.reject_matrix_updates()

This way, I think it's more transparent what's actually going on, also concerning the two loss types, albeit self.parameter_update_loss is computed once unnecessarily in the second iteration.
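For completeness, the helper called above could look like the following; that it compares the post-momentum candidates held in the kernel's state via the same Frobenius loss is my assumption, mirroring the loss_is_decreasing sketch in the original post:

def compute_parameter_update_loss(self) -> float:
    # loss of the post-momentum candidates, used only for tuning beta
    return compute_loss(
        self.utility_matrix_Z, self.low_rank_candidate_L, LossType.FROBENIUS
    )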

  • I'm not really sure why the algorithm tracks the history of Z; doesn't Z effectively get overwritten every step when we do min(0, L)?

We need to keep track of Z in case of an increasing parameter update loss, which indicates that the extrapolation was too extreme. In this case we first update the hyperparameters with self.decrease_momentum_parameters() and then reject the Z and L that resulted from the too-extreme extrapolation step using self.reject_matrix_updates(). The method does the following:

def reject_matrix_updates(self) -> None:
    """Reject updates on utility matrix Z and low rank candidate L

    Rejects updates by copying the previous low rank candidate L to the
    current low rank candidate L and the previous utility matrix Z to current.

    Returns:
    None
    """
    np.copyto(self.utility_matrix_Z, self.previous_utility_matrix_Z)
    np.copyto(self.low_rank_candidate_L, self.previous_low_rank_candidate_L)

We need to do this because in the same iteration, we first compute the new Z (note that it's not directly assigned to self.utility_matrix_Z):

utility_matrix_Z = construct_utility(
    self.low_rank_candidate_L, self.sparse_matrix_X
)

Then we do the extrapolation using self.utility_matrix_Z, which is the last Z that decreased the parameter update loss (since we rejected the last Z):

self.utility_matrix_Z = apply_momentum(
    utility_matrix_Z, self.utility_matrix_Z, self.momentum_beta
)

jsoules commented 2 months ago

Thanks for the writeup; you're right, it makes sense to properly document the discrepancies.

My pleasure. I know I'm coming into this without having looked at their algorithms very much, so I hope I'm still contributing something useful! :slightly_smiling_face:

[Pseudocode:] when decreasing momentum, beta-bar is set to the previous value of beta (i.e. from one iteration prior) ... [matlab:] when decreasing momentum, beta-bar is set to the value of beta from two iterations prior ...

I think this is due to notation.

:+1:

Your proposal computes the loss only if self.tolerance is not None

That is not the case [...]

Sounds good. Thanks for clarifying, & clarifying that they are intentionally using two definitions of loss. (Your proposed code implementation looks good also.)

  • I'm not really sure why the algorithm tracks the history of Z; doesn't Z effectively get overwritten every step when we do min(0, L)?

We need to keep track of Z in case of increasing parameter update loss [...]

Sure--I'm not sure what I was thinking here... I may have been thinking of the pseudocode (which suggests that an entire history of Z is kept, but of course it does, it's pseudocode...)