JuliaStats / NMF.jl

A Julia package for non-negative matrix factorization

Fix #15: Avoidance of zero division #39

Closed. ghost closed this pull request 4 years ago.

ghost commented 4 years ago

I fixed #15 by adding a small positive constant lambda to the denominator of the update rules.
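A minimal illustration of why the constant matters, assuming standard IEEE 754 floating-point semantics: a multiplicative update scales each entry by a ratio, so a zero denominator turns the entry into NaN (0/0) or Inf, and that value then propagates through every later iteration. The variable names are hypothetical and only mirror the H update rule discussed in this PR.

```julia
# Hypothetical scalar example of one multiplicative update step for an entry of H.
h   = 0.5     # current value of H[i,j]
wtq = 0.0     # numerator, e.g. (W' * Q)[i,j]; zero if column i of W is all zeros
sw  = 0.0     # denominator, e.g. the sum of column i of W

naive = h * (wtq / sw)                 # 0.0 / 0.0 is NaN under IEEE 754

lambda  = sqrt(eps(Float64))           # small positive constant (about 1.5e-8)
guarded = h * (wtq / (sw + lambda))    # 0.0 / lambda is 0.0, so the entry stays finite

println("naive = ", naive, ", guarded = ", guarded)
```

If the numerator were positive instead, the unguarded ratio would be Inf rather than NaN; the added constant keeps both cases finite.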

mschauer commented 4 years ago

I am a bit lost about the design here. Can we just choose reasonable default values for lambda_w and lambda_h and get rid of if-statements like this one:

    if lambda_h > zero(T)
        H[i,j] *= (WtQ[i,j] / (sW[i] + lambda_h))
    else
        H[i,j] *= (WtQ[i,j] / (sW[i] + delta))
    end

ghost commented 4 years ago

@mschauer Thank you for your review. That's a good suggestion. However, users may still override the defaults with zero, so I think it's safer to modify the code (lines 9-31 in src/multupd.jl) as follows:

mutable struct MultUpdate{T}
    obj::Symbol     # objective :mse or :div
    maxiter::Int    # maximum number of iterations
    verbose::Bool   # whether to show procedural information
    tol::T          # change tolerance upon convergence
    lambda_w::T     # L1 regularization coefficient for W
    lambda_h::T     # L1 regularization coefficient for H

    function MultUpdate{T}(;obj::Symbol=:mse,
                            maxiter::Integer=100,
                            verbose::Bool=false,
                            tol::Real=cbrt(eps(T)),
                            lambda_w::Real=zero(T),
                            lambda_h::Real=zero(T)) where T

        obj == :mse || obj == :div || throw(ArgumentError("Invalid value for obj."))
        maxiter > 1 || throw(ArgumentError("maxiter must be greater than 1."))
        tol > 0 || throw(ArgumentError("tol must be positive."))
        lambda_w >= 0 || throw(ArgumentError("lambda_w must be non-negative."))
        lambda_h >= 0 || throw(ArgumentError("lambda_h must be non-negative."))
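        # Under :div, clamp both coefficients to a small positive value so the
        # denominators in the update rules can never be zero, even if the user passes zero.
        if obj == :div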
            lambda_w = max(lambda_w, sqrt(eps(T)))
            lambda_h = max(lambda_h, sqrt(eps(T)))
        end
        new{T}(obj, maxiter, verbose, tol, lambda_w, lambda_h)
    end
end

As a result, the extra if-statements can be deleted. Could I commit the change?
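With the clamp in the constructor, the :div update can use lambda_h unconditionally. The code below is a hypothetical, self-contained sketch of such a branch-free H update, not the actual code in src/multupd.jl. It assumes the standard generalized KL-divergence multiplicative rule for X ≈ W * H, with WtQ = W' * (X ./ (W * H)) and sW the column sums of W, matching the names used in the if-statement quoted earlier. The helpers `effective_lambda` and `update_H_div!` are illustrative names, not NMF.jl API.

```julia
using LinearAlgebra

# Hypothetical helper mirroring the constructor's clamp: for :div, never let
# the regularization coefficient drop below sqrt(eps(T)).
effective_lambda(obj::Symbol, lambda::T) where {T<:AbstractFloat} =
    obj == :div ? max(lambda, sqrt(eps(T))) : lambda

# Hypothetical sketch (not NMF.jl's internal code): one multiplicative update of H
# for the generalized KL-divergence (:div) objective, X ≈ W * H.
function update_H_div!(H::Matrix{T}, W::Matrix{T}, X::Matrix{T}, lambda_h::T) where {T<:AbstractFloat}
    Q   = X ./ (W * H)              # elementwise ratio; assumes W * H has no zero entries
    WtQ = W' * Q                    # k × n
    sW  = vec(sum(W, dims=1))       # column sums of W, length k
    k, n = size(H)
    for j in 1:n, i in 1:k
        # lambda_h > 0, so the denominator is positive even when sW[i] == 0
        H[i,j] *= WtQ[i,j] / (sW[i] + lambda_h)
    end
    return H
end

# Usage: the second column of W is all zeros, so sW[2] == 0.
W = [1.0 0.0; 2.0 0.0; 0.5 0.0]
H = [1.0 2.0; 3.0 4.0]
X = [1.0 2.0; 2.0 4.0; 1.0 1.0]
lambda_h = effective_lambda(:div, 0.0)   # user passed zero; clamped to sqrt(eps(Float64))
update_H_div!(H, W, X, lambda_h)
println(H)
```

Here row 2 of H becomes exactly 0.0 instead of NaN; calling `update_H_div!` with `lambda_h = 0.0` would compute 0/0 for that row. Clamping once in the constructor keeps the per-element inner loop free of branches.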

mschauer commented 4 years ago

Yes!

ghost commented 4 years ago

@mschauer I committed the change, but some checks were not successful. I'm not familiar with coverage/coveralls, so I don't know how to remedy the problem.

mschauer commented 4 years ago

17 of 17 new or added lines in 1 file covered.

All is well.

ghost commented 4 years ago

@mschauer Thank you for your comment and approval!

ghost commented 4 years ago

@ararslan Thank you for your approval!