MadryLab / trak

A fast, effective data attribution method for neural networks in PyTorch
https://trak.csail.mit.edu/
MIT License

Implement soft thresholding #10

Closed: kristian-georgiev closed this issue 1 year ago

kristian-georgiev commented 1 year ago

https://github.com/MadryLab/trak/blob/b37289c75b4c9210084b2579ccc8f1b7438f1d20/trak/traker.py#L353-L355

kristian-georgiev commented 1 year ago

Thinking about it a bit more, soft thresholding should not be done by the trak API (https://github.com/MadryLab/trak/commit/72fc94c0625a117103e7d362e6b7fd2224f4de54). To keep the TRAKer class simple, it should instead be applied outside TRAKer, as a post-processing step on the computed scores. In the general case it requires cross-validation to choose the number of non-zero components kept per column (nnz below), which does not belong inside TRAKer. A simple implementation is:

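import numpy as np
from tqdm import tqdm


def soft_threshold_matrix(infls_, nnz):
    # Soft-threshold each column of infls_ so at most nnz entries per column stay non-zero.
    infls_soft = np.array(infls_, dtype=float)  # float copy of the input
    n = infls_soft.shape[0]
    if nnz >= n:
        return infls_soft  # tau = 0: nothing to zero out or shrink

    for j in tqdm(range(infls_soft.shape[1])):
        col = infls_soft[:, j]
        # tau = largest magnitude among the (n - nnz) smallest-magnitude entries (inf if nnz <= 0)
        tau = np.sort(np.abs(col))[n - nnz - 1] if nnz > 0 else np.inf
        # soft thresholding: sign(x) * max(|x| - tau, 0); entries with |x| <= tau become 0
        infls_soft[:, j] = np.sign(col) * np.maximum(np.abs(col) - tau, 0.0)

    return infls_soft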
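For larger score matrices, the per-column loop in soft_threshold_matrix can be vectorized, and nnz can then be picked by sweeping candidate values against some validation metric. The sketch below is an illustration under assumptions, not part of the trak API: `soft_threshold_columns`, `select_nnz`, and `evaluate` are hypothetical names, the scores are assumed to be a 2-D array with one column per target, and `evaluate` stands in for whatever cross-validation or held-out metric you use (higher is better).

```python
import numpy as np


def soft_threshold_columns(scores, nnz):
    """Hypothetical vectorized per-column soft thresholding.

    For each column, tau is the largest magnitude among its (n - nnz)
    smallest-magnitude entries; each entry x becomes sign(x) * max(|x| - tau, 0),
    so at most `nnz` entries per column remain non-zero.
    """
    scores = np.asarray(scores, dtype=float)
    n = scores.shape[0]
    if nnz >= n:
        return scores.copy()          # tau = 0: nothing to shrink
    if nnz <= 0:
        return np.zeros_like(scores)  # every entry is thresholded away
    # Per-column tau: the (n - nnz)-th smallest |x| (0-based index n - nnz - 1),
    # found with a partial sort instead of a full sort.
    kth = n - nnz - 1
    tau = np.partition(np.abs(scores), kth, axis=0)[kth]  # shape (num_columns,)
    return np.sign(scores) * np.maximum(np.abs(scores) - tau, 0.0)


def select_nnz(scores, candidates, evaluate):
    """Hypothetical grid search over nnz.

    `evaluate` is a user-supplied callable (assumption, not a trak function)
    mapping a thresholded score matrix to a scalar, e.g. a cross-validated
    agreement with held-out model outputs. Returns the best nnz and all scores.
    """
    results = {k: evaluate(soft_threshold_columns(scores, k)) for k in candidates}
    best = max(results, key=results.get)
    return best, results


# Toy usage on synthetic data (not TRAK output): sparse "true" scores plus noise.
rng = np.random.default_rng(0)
truth = np.zeros((200, 5))
truth[:10] = 3.0
noisy = truth + 0.1 * rng.normal(size=truth.shape)
best, results = select_nnz(noisy, [5, 10, 20, 50],
                           evaluate=lambda s: -np.mean((s - truth) ** 2))
print(best, results)
```

np.partition places each column's threshold at index kth without fully sorting the column. Entries whose magnitude equals tau exactly are set to zero, which matches the sign(x) * max(|x| - tau, 0) definition of soft thresholding.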