Brilliant!

I also want to point out that Hutchinson can be extended to compute the trace of H^2, as follows.

Sketch of classical Hutchinson to estimate tr(H):

result = 0
for _ in range(N):
    x = sample_rademacher()
    x = l2_normalize(x)
    result += x @ hessian_vector_product(H, x)  # x.T @ H @ x
result /= N
This works since tr(H) = <H, I> = <H, mean(v @ v.T)>, where the v are i.i.d. Rademacher vectors. Also, we don't normalize the sample vectors, since Rademacher guarantees that the diagonal entries of our "I" (the sum of the v @ v.T) all equal N, so we just divide by N.
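To make the unbiasedness explicit (standard Hutchinson reasoning, spelled out here for completeness):

E[v.T @ H @ v] = E[tr(H @ v @ v.T)] = tr(H @ E[v @ v.T]) = tr(H @ I) = tr(H),

since for Rademacher v, E[v_i^2] = 1 and E[v_i v_j] = 0 for i != j, so E[v @ v.T] = I.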
Now, the only modification we need to compute tr(H^2) is to replace the x.T @ H @ x operation with x.T @ H @ H @ x. This can be achieved at no extra cost in memory or computation by changing a single line, namely:

result += l2_norm_squared(hessian_vector_product(H, x))

since || Hx ||_2^2 indeed equals x.T @ H @ H @ x.
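A quick numerical sanity check of this identity and of the resulting estimator (a minimal sketch; the test matrix, seed, and sample count are made up for illustration):

import numpy as np

rng = np.random.default_rng(0)
d, N = 50, 10_000

# a random symmetric "Hessian" for illustration
A = rng.standard_normal((d, d))
H = (A + A.T) / 2

# identity: ||Hx||^2 == x.T @ H @ H @ x
x = rng.standard_normal(d)
assert np.isclose((H @ x) @ (H @ x), x @ H @ H @ x)

# Hutchinson estimate of tr(H^2) via squared Hessian-vector products
result = 0.0
for _ in range(N):
    v = rng.choice([-1.0, 1.0], size=d)  # Rademacher sample
    result += np.sum((H @ v) ** 2)       # ||Hv||^2 = v.T @ H^2 @ v
result /= N

print(result, np.trace(H @ H))  # the two should be close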
I believe this operation can be useful whenever we care about the magnitude of the trace or the top spectrum, since H^2 is PSD while H in general is not; note that for symmetric H, tr(H^2) is the sum of squared eigenvalues, i.e. the squared Frobenius norm of H. And it is trivial to implement with zero overhead.
Here is my current implementation, hope it helps:
import numpy as np


class HutchinsonSquared:
    """Hutchinson-style estimator of tr(H^2) via squared Hessian-vector products."""

    @staticmethod
    def rademacher(dims, nonzero_idxs, dtype=np.float32):
        """Draw a Rademacher vector of length ``dims``.

        If ``nonzero_idxs`` is given, only those entries are sampled and the
        remaining entries are set to zero.
        """
        if nonzero_idxs is None:
            sample = ((np.random.rand(dims) < 0.5) * 2 - 1).astype(dtype)
        else:
            noise = (np.random.rand(len(nonzero_idxs)) < 0.5) * 2 - 1
            sample = np.zeros(dims, dtype=dtype)
            sample[nonzero_idxs] = noise.astype(dtype)
        return sample

    @classmethod
    def hutchinson_squared(cls, hessian, num_iters=1, nonzero_idxs=None):
        """
        :returns: An estimate of ``tr(H^2)``, restricted to the diagonal entries
            corresponding to ``nonzero_idxs`` if given.
        """
        _, w = hessian.shape
        result = 0
        for _ in range(num_iters):
            sample = cls.rademacher(w, nonzero_idxs, hessian.dtype)
            sample = hessian @ sample  # Hv; its squared norm is v.T @ H^2 @ v
            result += (sample * sample).sum()
        # we don't normalize sample vectors, since Rademacher guarantees
        # that all diagonal entries of R @ R.T will equal num_iters.
        # So to approximate identity, just divide by num_iters.
        result /= num_iters
        return result
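For reference, a possible usage sketch against a dense symmetric matrix (the test matrix, seed, and sample count here are made up for illustration):

import numpy as np

rng = np.random.default_rng(42)
A = rng.standard_normal((100, 100)).astype(np.float32)
H = (A + A.T) / 2  # symmetric test "Hessian"

estimate = HutchinsonSquared.hutchinson_squared(H, num_iters=5_000)
exact = np.trace(H @ H)
print(estimate, exact)  # the estimate should be close to the exact value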
Interesting. I know there are estimation techniques for the trace of a matrix function F(A) (the so-called generalized trace, see the paragraph above Equation 2.1).

What should an interface for Hutchinson look like? Suggestion:

def hutchinson_trace(
    A: LinearOperator,
    num_samples: int,
    sampling: str = 'rademacher',  # could also be 'normal'
    return_samples: bool = False,  # allow access to individual samples, e.g. to estimate the variance
) -> Union[float, Tuple[float, List[float]]]:
    ...
    # only return the mean unless the user requests individual samples

Any thoughts?
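A minimal sketch of how this signature could be filled in on top of scipy's LinearOperator (the body is my assumption of the intended behavior, not a settled implementation):

from typing import List, Tuple, Union

import numpy as np
from scipy.sparse.linalg import LinearOperator


def hutchinson_trace(
    A: LinearOperator,
    num_samples: int,
    sampling: str = "rademacher",
    return_samples: bool = False,
) -> Union[float, Tuple[float, List[float]]]:
    """Estimate tr(A) with Hutchinson's method (sketch)."""
    rng = np.random.default_rng()
    dim = A.shape[1]
    samples = []
    for _ in range(num_samples):
        if sampling == "rademacher":
            v = rng.choice([-1.0, 1.0], size=dim)
        elif sampling == "normal":
            v = rng.standard_normal(dim)
        else:
            raise ValueError(f"Unknown sampling: {sampling!r}")
        samples.append(float(v @ A.matvec(v)))  # one unbiased sample of tr(A)
    mean = float(np.mean(samples))
    return (mean, samples) if return_samples else mean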
I think a class-based approach may be better, especially for Hutch++, which has a pre-computation step (computing an orthonormal basis and the trace in its span).

class HutchinsonTrace:
    def __init__(self, A: LinearOperator):
        self.A = A

    def sample(self, distribution='rademacher') -> float:
        ...
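To illustrate why the class helps for Hutch++, here is a rough sketch of how the pre-computation could live in the constructor; it follows the recipe from the meyer2020hutch paper, and the class name, basis size, and details are my assumptions rather than a proposed final API:

import numpy as np
from scipy.sparse.linalg import LinearOperator


class HutchPlusPlusTrace:
    """Sketch: Hutch++ with the basis pre-computation in the constructor."""

    def __init__(self, A: LinearOperator, basis_dim: int = 10):
        self.A = A
        self.rng = np.random.default_rng()
        dim = A.shape[1]
        # pre-computation (1): orthonormal basis of A @ S for a random sketch S
        S = self.rng.choice([-1.0, 1.0], size=(dim, basis_dim))
        AS = np.column_stack([A.matvec(S[:, i]) for i in range(basis_dim)])
        self.Q, _ = np.linalg.qr(AS)
        # pre-computation (2): exact trace of A restricted to span(Q)
        AQ = np.column_stack([A.matvec(self.Q[:, i]) for i in range(basis_dim)])
        self.trace_in_span = float(np.trace(self.Q.T @ AQ))

    def sample(self) -> float:
        """Trace in span(Q) plus one Hutchinson sample of the deflated operator."""
        g = self.rng.choice([-1.0, 1.0], size=self.A.shape[1])
        g = g - self.Q @ (self.Q.T @ g)  # project out span(Q)
        return self.trace_in_span + float(g @ self.A.matvec(g))

Averaging several sample() calls then yields the Hutch++ estimate, while the pre-computed part stays fixed across samples.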
Quick correction: The l2_normalize is not required in Hutchinson's trace estimator. Sorry for dropping the ball.
The trace is often used to summarize curvature matrices in second-order methods or for generalization metrics.
I could not find libraries that provide trace estimation methods for scipy.sparse.LinearOperators. The closest library is Nico's matfree, which offers Hutchinson trace estimation for JAX. pyhessian has Hutchinson trace estimation in PyTorch, but does not use a LinearOperator interface and only considers the Hessian. So it would be useful to offer trace estimation through a scipy-based linear operator interface in this library.

Possible algorithms are:
NA-Hutch++ (paper): I decided against implementing NA-Hutch++, since it does not offer memory savings over Hutch++. According to the paper, non-adaptive methods have practical benefits when used with batch multiplies of the linear operator. The linear operators offered by this library, however, only support efficient matvecs (matmats are for loops) and hence do not allow leveraging this benefit. Another point against implementing and maintaining this method is that, according to the meyer2020hutch paper, NA-Hutch++ "tends to perform slightly worse in our experiments."
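To illustrate the matvec-vs-matmat point: when no dedicated matmat is supplied, scipy's LinearOperator falls back to one matvec per column, which is exactly the for-loop situation described above (a minimal sketch with a made-up operator):

import numpy as np
from scipy.sparse.linalg import LinearOperator

calls = {"matvec": 0}


def matvec(v):
    # made-up operator, standing in for e.g. a Hessian-vector product
    calls["matvec"] += 1
    return 2.0 * v


A = LinearOperator(shape=(4, 4), matvec=matvec, dtype=np.float64)

# Without a dedicated matmat, applying A to a matrix loops over
# its columns, calling matvec once per column.
A.matmat(np.eye(4))
print(calls["matvec"])  # 4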