flatironinstitute / nomad

Non-linear Matrix Decomposition library
Apache License 2.0

[Feature Request] Add kernel momentum 3-block NMD from Seraghiti et al. (2023) #13

Closed sfohr closed 8 months ago

sfohr commented 9 months ago

Which of these best describes your feature request:

Describe how the new feature would improve the library: I suggest adding the momentum 3-block NMD kernel described in Seraghiti et al. (2023). It extends the current base model-free kernel: like that kernel, it alternates between constructing the utility matrix Z and computing the low-rank approximation. In contrast to the base model-free kernel:

Describe the solution you'd like: Leaving aside most convergence criteria, their MATLAB implementation [2] uses the following input parameters:

And returns the low-rank factors W, H.

The following challenges come with their implementation:

My idea for the step method:

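def step(self) -> None:
    # Extrapolate the low-rank candidate L. This is skipped on the first
    # iteration, when there is no previous candidate yet.
    if self.elapsed_iterations > 0:
        self.low_rank_candidate_L = apply_momentum(
            self.low_rank_candidate_L,
            self.previous_low_rank_candidate_L,
            self.momentum_beta,
        )

    # Rebuild the utility matrix Z from the current candidate, then
    # extrapolate it against the previous Z.
    utility_matrix_Z = construct_utility(
        self.low_rank_candidate_L,
        self.sparse_matrix_X,
    )
    self.utility_matrix_Z = apply_momentum(
        utility_matrix_Z,
        self.utility_matrix_Z,
        self.momentum_beta,
    )

    # Remember the current product W @ H before the factors change.
    self.previous_low_rank_candidate_L = self.factor_W @ self.factor_H

    # Update the two factor blocks in turn, then refresh the candidate.
    self.factor_W = update_W(self.factor_H, self.utility_matrix_Z)
    self.factor_H = update_H(self.factor_W, self.utility_matrix_Z)

    self.low_rank_candidate_L = self.factor_W @ self.factor_H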
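
The helpers called in step are not defined in this thread. Below is a minimal NumPy sketch of what they might look like, assuming the ReLU-NMD setting X ≈ max(0, WH): the utility matrix copies X where X is positive and takes min(0, L) elsewhere, momentum is the usual extrapolation current + beta * (current - previous), and the factor updates are unconstrained least-squares solves. The function bodies are my assumptions for illustration, not the library's existing code or the reference MATLAB implementation.

```python
import numpy as np


def apply_momentum(current, previous, beta):
    """Extrapolate: current + beta * (current - previous). Assumed form."""
    return current + beta * (current - previous)


def construct_utility(low_rank_candidate_L, sparse_matrix_X):
    """Assumed rule: Z = X where X > 0, and min(0, L) where X == 0."""
    return np.where(
        sparse_matrix_X > 0,
        sparse_matrix_X,
        np.minimum(0.0, low_rank_candidate_L),
    )


def update_W(factor_H, utility_matrix_Z):
    """Least-squares W for Z ~ W @ H, solved as H.T @ W.T ~ Z.T."""
    solution, *_ = np.linalg.lstsq(factor_H.T, utility_matrix_Z.T, rcond=None)
    return solution.T


def update_H(factor_W, utility_matrix_Z):
    """Least-squares H for Z ~ W @ H."""
    solution, *_ = np.linalg.lstsq(factor_W, utility_matrix_Z, rcond=None)
    return solution


# Tiny synthetic run: one pass over the three blocks, without momentum.
rng = np.random.default_rng(0)
X = np.maximum(0.0, rng.normal(size=(6, 2)) @ rng.normal(size=(2, 5)))
W = rng.normal(size=(6, 2))
H = rng.normal(size=(2, 5))

Z = construct_utility(W @ H, X)
W = update_W(H, Z)
H = update_H(W, Z)
L = W @ H
```

Using lstsq rather than an explicit inverse is a design choice of this sketch; it avoids forming H @ H.T or W.T @ W and still behaves sensibly if a factor is rank deficient.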

Utility functions: The function construct_utility is shared with the base model-free kernel, and apply_momentum can be reused in the Accelerated Momentum NMD algorithm [1]. I'd rename base_model_free_util.py to model_free_util.py for functions that are shared across the model-free kernels, and add kernel-specific functions such as update_W to 3b_model_free_util.py.

kernelInputTypes

I'm unfamiliar with how the input types are defined, specifically the kernel-specific parameters. For the sake of simplicity, let's pretend 3B-NMD would be instantiated with L as well. Would the following work?

class Momentum3BModelFreeInputType(KernelInputType):
    beta: float

KernelSpecificParameters = Union[float, int, Momentum3BModelFreeInputType]

kernelReturnTypes: 3B-NMD returns W and H, so it adds nothing to the base return type.

References

jsoules commented 9 months ago

Hi,

Thanks for your clear and thorough proposal. I've now read it a few times and I agree that it's a good approach.

A few answers, suggestions, and acknowledgements:

Kernel initialized by W0 & H0, instead of L

KernelSpecificParameters = ...

This is almost exactly right. KernelSpecificParameters was implemented to allow exactly what you're trying to do. (I just haven't had an opportunity to actually test it out yet.)

So what you would want to do is something like the following.

In src/fi_nomad/types/kernelInputTypes.py:

# imports...

class Momentum3BlockAdditionalParameters(NamedTuple):
    W0: FloatArrayType
    H0: FloatArrayType
    beta: float
    maxiters: int

KernelSpecificParameters = Union[float, int, Momentum3BlockAdditionalParameters]
# ...end of changes

(Note that I think the float and int in that union are only there to keep the linter happy. I don't recall the exact reason; I think the linter might have complained if the type was sufficiently constrained for it to infer the exact type elsewhere. The point is that you might be able to delete one or both of them, but we can worry about that another time.)

Then in src/fi_nomad/types/enums.py:

# in class KernelStrategy, add

MOMENTUM_3_BLOCK_MODEL_FREE = 5   # or whatever the next number is

In src/fi_nomad/util/factory_util.py:

# in instantiate_kernel(), add an elif:
#...
elif s == KernelStrategy.GAUSSIAN_MODEL_ROWWISE_VARIANCE:
    kernel = # etc
elif s == KernelStrategy.MOMENTUM_3_BLOCK_MODEL_FREE:
    kernel = Momentum3BlockModelFreeKernel(data_in, kernel_params)
# note you may need to cast kernel_params to Momentum3BlockAdditionalParameters to keep the linter happy
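
A rough sketch of what that cast could look like, using only the standard library. The enum and the parameter tuple below are reduced stand-ins so the snippet runs on its own (the tuple omits W0 and H0, and the enum value is a placeholder); describe_params is a hypothetical function that only mimics the shape of the elif branch.

```python
from enum import Enum
from typing import NamedTuple, Union, cast


class KernelStrategy(Enum):
    # Stand-in enum with only the new member; the value is a placeholder.
    MOMENTUM_3_BLOCK_MODEL_FREE = 5


class Momentum3BlockAdditionalParameters(NamedTuple):
    # Reduced stand-in: the array fields W0 and H0 are left out here.
    beta: float
    maxiters: int


KernelSpecificParameters = Union[float, int, Momentum3BlockAdditionalParameters]


def describe_params(s: KernelStrategy, kernel_params: KernelSpecificParameters) -> str:
    """Hypothetical dispatch that mirrors the shape of the elif branch."""
    if s == KernelStrategy.MOMENTUM_3_BLOCK_MODEL_FREE:
        # cast() has no runtime effect; it only tells the type checker
        # which member of the union to assume in this branch.
        params = cast(Momentum3BlockAdditionalParameters, kernel_params)
        return f"beta={params.beta}, maxiters={params.maxiters}"
    raise ValueError(f"Unsupported strategy: {s}")


summary = describe_params(
    KernelStrategy.MOMENTUM_3_BLOCK_MODEL_FREE,
    Momentum3BlockAdditionalParameters(beta=0.7, maxiters=100),
)
```

Because cast() performs no check, an isinstance test on kernel_params would be the alternative if a wrong parameter type should fail loudly at runtime.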

And in the constructor (__init__) of your Momentum3BlockModelFreeKernel, use a signature like

def __init__(self, indata: KernelInputType, custom_params: Momentum3BlockAdditionalParameters):

and you can initialize whatever values you need within this kernel from the custom parameters data object, including beta, maxiters, W0, H0, ...
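
As a rough illustration of that unpacking, here is a self-contained sketch. It takes the sparse matrix directly in place of a KernelInputType (whose fields aren't shown in this thread), uses np.ndarray as a stand-in for FloatArrayType, and adds shape checks that are my own assumption. The class name and attribute names are hypothetical.

```python
from typing import NamedTuple

import numpy as np

FloatArrayType = np.ndarray  # stand-in for the library's alias


class Momentum3BlockAdditionalParameters(NamedTuple):
    W0: FloatArrayType
    H0: FloatArrayType
    beta: float
    maxiters: int


class Momentum3BlockKernelSketch:
    """Hypothetical kernel skeleton; not the library's base-class API."""

    def __init__(
        self,
        sparse_matrix_X: FloatArrayType,
        custom_params: Momentum3BlockAdditionalParameters,
    ) -> None:
        W0, H0 = custom_params.W0, custom_params.H0
        rows, cols = sparse_matrix_X.shape
        # Assumed sanity check: W0 @ H0 must have the same shape as X.
        if W0.shape[0] != rows or H0.shape[1] != cols or W0.shape[1] != H0.shape[0]:
            raise ValueError("W0 @ H0 must have the same shape as X")

        self.sparse_matrix_X = sparse_matrix_X
        # Copy so later in-place updates do not modify the caller's arrays.
        self.factor_W = W0.copy()
        self.factor_H = H0.copy()
        self.momentum_beta = custom_params.beta
        self.max_iterations = custom_params.maxiters


params = Momentum3BlockAdditionalParameters(
    W0=np.ones((3, 2)), H0=np.ones((2, 4)), beta=0.5, maxiters=10
)
kernel = Momentum3BlockKernelSketch(np.zeros((3, 4)), params)
```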

Long story short, I don't think it will be necessary to alter the KernelInputType, but let me know if you run into trouble.

I'd rename base_model_free_util.py into model_free_util.py

I'm happy with that

and add 3b_model_free_util.py

Happy with this too, although I think there may be a restriction about files/modules not beginning with a number. If so, just spell it out.
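
The restriction exists at the language level: Python identifiers cannot begin with a digit, so a normal import statement for 3b_model_free_util does not even compile. A quick check, where three_block_model_free_util is only a hypothetical spelled-out alternative:

```python
candidate_names = ["3b_model_free_util", "three_block_model_free_util"]

# str.isidentifier() reports whether a name is a valid Python identifier.
importable = {name: name.isidentifier() for name in candidate_names}


def import_statement_compiles(module_name: str) -> bool:
    """Compile (without executing) an import statement for the given name."""
    try:
        compile(f"import {module_name}", "<sketch>", "exec")
    except SyntaxError:
        return False
    return True
```

Such a file could still be loaded by name through importlib, but it would be awkward to use everywhere else in the package, so spelling out the number is the simpler route.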

let's pretend 3B-NMD would be instantiated with L as well

I haven't double-checked, but I think you can always just ignore the L. I suppose this could result in some unnecessary computation at the beginning, such as copying memory or finding the mean. If you want to avoid that, two further possibilities:

  1. You can use InitializationStrategy.KNOWN_MATRIX and pass in an empty or very small matrix as the "known" value. (However, there may be a check that it matches the dimensions of the high-rank sparse nonnegative matrix, so that might not work.)
  2. Alternatively, you can add an entry such as IGNORE = 5 to the InitializationStrategy enum and update the initialize_candidate function in fi_nomad/util/initialization_util.py to do nothing when that strategy is chosen.

Realistically, this is unlikely to affect performance enough to bother doing anything other than just ignoring the L and maybe initializing the kernel's actual L value when elapsed_iterations == 0.
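
A minimal sketch of that last option, assuming the kernel keeps W0 and H0 around and that step is responsible for advancing elapsed_iterations (in the real base class that bookkeeping may live elsewhere). The class here is a hypothetical stand-in, not the library's kernel.

```python
import numpy as np


class LazyCandidateSketch:
    """Hypothetical stand-in: ignore the supplied L, build it on first step."""

    def __init__(self, W0: np.ndarray, H0: np.ndarray) -> None:
        self.factor_W = W0
        self.factor_H = H0
        self.elapsed_iterations = 0
        # Whatever L the framework supplied is ignored; start with none.
        self.low_rank_candidate_L = None

    def step(self) -> None:
        if self.elapsed_iterations == 0:
            # First iteration: derive the kernel's actual L from W0 and H0.
            self.low_rank_candidate_L = self.factor_W @ self.factor_H
        # ... the utility, momentum, and factor updates would go here ...
        self.elapsed_iterations += 1


sketch = LazyCandidateSketch(np.ones((3, 2)), np.ones((2, 4)))
candidate_before_first_step = sketch.low_rank_candidate_L
sketch.step()
```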

My only other note is that I think it'd be better to spell out "Block" instead of "B" for the class/etc names.

--

Other than that, I think this is a great proposal and I look forward to seeing where you go with it. Please let me know what additional assistance you need or if you run into any snags.