f-dangel / curvlinops

scipy linear operators for the Hessian, Fisher/GGN, and more in PyTorch
https://curvlinops.readthedocs.io/en/latest/
MIT License

Feature request: Trace estimation algorithms #30

Closed · f-dangel closed this issue 1 year ago

f-dangel commented 1 year ago

The trace is often used to summarize curvature matrices in second-order methods or for generalization metrics.

I could not find libraries that provide trace estimation methods for scipy.sparse.LinearOperators. The closest library is Nico's matfree, which provides Hutchinson trace estimation for JAX. pyhessian has Hutchinson trace estimation in PyTorch, but it does not use a LinearOperator interface and only considers the Hessian.

So it would be useful for this library to offer trace estimation through a scipy-based linear operator interface.

Possible algorithms (both come up in the discussion below) are:

- Hutchinson's estimator
- Hutch++

andres-fr commented 1 year ago

Brilliant!

I also want to point out that Hutchinson can be extended to compute the trace of H^2, as follows:

Sketch of classical Hutchinson to estimate tr(H):

result = 0
for _ in range(N):
    x = sample_rademacher()
    x = l2_normalize(x)
    result += x @ hessian_vector_product(H, x)  # x.T @ H @ x
result /= N

This works since tr(H) = <H, I> = <H, mean(v @ v.T)>, where the v are i.i.d. Rademacher vectors. Also, we don't need to normalize the sample vectors, since Rademacher sampling guarantees that the diagonal entries of our accumulated "I" all equal N; we just divide by N.
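
For reference, the identity behind the estimator, written out (a standard derivation, not part of the original comment; the expectation is over i.i.d. Rademacher vectors v):

$$
\mathbb{E}[v v^\top] = I
\quad \Longrightarrow \quad
\mathbb{E}\left[v^\top H v\right]
= \mathbb{E}\left[\operatorname{tr}(H v v^\top)\right]
= \operatorname{tr}\left(H \, \mathbb{E}[v v^\top]\right)
= \operatorname{tr}(H).
$$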

Now, the only modification needed to compute tr(H^2) is to replace the x.T @ H @ x operation with x.T @ H @ H @ x. This can be achieved at no extra cost in memory or computation by changing a single line, namely:

result += l2_norm_squared(hessian_vector_product(H, x))

since ||Hx||_2^2 = x.T @ H.T @ H @ x, which indeed equals x.T @ H @ H @ x because H is symmetric.


I believe this operation can be useful whenever we care about the magnitude of the trace or the top spectrum, since H^2 is PSD while H in general is not. And it is trivial to implement with zero overhead.

andres-fr commented 1 year ago

Here is my current implementation, hope it helps:

import numpy as np


class HutchinsonSquared:
    """Hutchinson-style estimator of tr(H^2) via squared Hessian-vector products."""

    @staticmethod
    def rademacher(dims, nonzero_idxs=None, dtype=np.float32):
        """Draw a Rademacher vector of length ``dims``.

        If ``nonzero_idxs`` is given, all other entries are set to zero.
        """
        if nonzero_idxs is None:
            sample = ((np.random.rand(dims) < 0.5) * 2 - 1).astype(dtype)
        else:
            noise = (np.random.rand(len(nonzero_idxs)) < 0.5) * 2 - 1
            sample = np.zeros(dims, dtype=dtype)
            sample[nonzero_idxs] = noise.astype(dtype)
        return sample

    @classmethod
    def hutchinson_squared(cls, hessian, num_iters=1, nonzero_idxs=None):
        """
        :returns: An estimate of ``tr(H^2)``, restricted to the entries
          corresponding to ``nonzero_idxs`` if given.
        """
        _, w = hessian.shape
        result = 0
        for _ in range(num_iters):
            sample = cls.rademacher(w, nonzero_idxs, hessian.dtype)
            sample = hessian @ sample  # one Hessian-vector product
            result += (sample * sample).sum()  # ||H x||^2 = x.T @ H^2 @ x
        # we don't normalize sample vectors, since Rademacher guarantees
        # that all diagonal entries of R @ R.T will equal num_iters.
        # So to approximate the identity, just divide by num_iters.
        result /= num_iters
        return result
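
A quick sanity check on a small dense matrix (my addition, not part of the original comment; H is just a random symmetric stand-in for a Hessian):

import numpy as np

dim = 100
A = np.random.randn(dim, dim).astype(np.float32)
H = A + A.T  # symmetric test matrix

exact = np.trace(H @ H)
estimate = HutchinsonSquared.hutchinson_squared(H, num_iters=1000)
print(exact, estimate)  # the two values should roughly agree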
f-dangel commented 1 year ago

[quoting @andres-fr's comment above on extending Hutchinson to tr(H^2)]

Interesting. I know there are estimation techniques for the trace of a matrix function F(A) (the so-called generalized trace, see the paragraph above Equation 2.1).

f-dangel commented 1 year ago

How should an interface for Hutchinson look? Suggestion:

from typing import List, Tuple, Union

from scipy.sparse.linalg import LinearOperator


def hutchinson_trace(
    A: LinearOperator,
    num_samples: int,
    sampling: str = 'rademacher',  # could also be 'normal'
    return_samples: bool = False,  # allow access to individual samples, e.g. to estimate the variance
) -> Union[float, Tuple[float, List[float]]]:
    ...
    # only return the mean unless the user requests individual samples
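
For concreteness, a minimal sketch of what the body could look like (my addition, not a settled implementation; assumes A is square and supports matrix-vector products via @):

import numpy as np


def hutchinson_trace(A, num_samples, sampling='rademacher', return_samples=False):
    """Minimal sketch of the interface above: Hutchinson trace estimation."""
    dim = A.shape[1]
    samples = []
    for _ in range(num_samples):
        if sampling == 'rademacher':
            v = (np.random.rand(dim) < 0.5) * 2.0 - 1.0
        elif sampling == 'normal':
            v = np.random.randn(dim)
        else:
            raise ValueError(f'Unknown sampling: {sampling}')
        samples.append(v @ (A @ v))  # one unbiased estimate of tr(A)
    mean = sum(samples) / num_samples
    return (mean, samples) if return_samples else mean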

any thoughts?

f-dangel commented 1 year ago

I think a class-based approach may be better, especially for Hutch++, which has a pre-computation step (computing an orthonormal basis and the exact trace within its span).

class HutchinsonTrace:

    def __init__(self, A: LinearOperator):
        self.A = A

    def sample(self, distribution: str = 'rademacher') -> float:
        ...
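
To illustrate how the pre-computation would fit such a class, here is a rough sketch of Hutch++ (following Meyer et al., 2021; my addition, not the library's API — the class name and basis_dim are made up, and A is assumed to support @ with vectors and matrices):

import numpy as np


class HutchPPTrace:
    """Sketch of Hutch++: exact trace on a low-rank subspace plus a residual estimate."""

    def __init__(self, A, basis_dim=10):
        self.A = A
        dim = A.shape[1]
        # pre-computation: orthonormal basis Q of a random sketch of A ...
        S = np.sign(np.random.randn(dim, basis_dim))
        self.Q, _ = np.linalg.qr(A @ S)
        # ... and the exact trace of A restricted to span(Q)
        self.low_rank_trace = np.trace(self.Q.T @ (A @ self.Q))

    def sample(self) -> float:
        """One unbiased estimate of tr(A): exact low-rank part + residual Hutchinson."""
        v = np.sign(np.random.randn(self.A.shape[1]))
        v = v - self.Q @ (self.Q.T @ v)  # project out span(Q)
        return self.low_rank_trace + v @ (self.A @ v)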
f-dangel commented 1 year ago

[quoting @andres-fr's comment above, including the l2_normalize step in the Hutchinson sketch]

Quick correction: the l2_normalize is not required in Hutchinson's trace estimator. For Rademacher vectors, ||x||_2^2 equals the dimension d exactly, so normalizing would shrink the estimate to tr(H) / d.

andres-fr commented 1 year ago

Sorry for dropping the ball.