kristian-georgiev closed this issue 1 year ago.
Thinking about it a bit more, this soft-thresholding should not be done inside the trak API (https://github.com/MadryLab/trak/commit/72fc94c0625a117103e7d362e6b7fd2224f4de54). To keep the TRAKer class simple, it should instead be applied outside, as a post-processing step on the computed scores. In the general case, the number of non-zero components per column (nnz below) has to be chosen by cross-validation, which does not belong in TRAKer. A simple implementation is:
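import numpy as np
from tqdm import tqdm

def soft_threshold_matrix(infls_, nnz):
    # Soft-threshold each column of infls_, keeping at most nnz non-zero
    # entries per column; assumes 0 < nnz < infls_.shape[0].
    infls_soft = infls_.copy()
    for j in tqdm(range(infls_.shape[1])):
        col = infls_soft[:, j]
        # indices of all but the nnz largest-magnitude entries
        bot_indices = np.argsort(np.abs(col))[:-nnz]
        # threshold = largest magnitude among the discarded entries
        tau = np.abs(col)[bot_indices[-1]]
        # shrink towards zero by tau; entries with |x| <= tau (ties included) become 0
        infls_soft[:, j] = np.sign(col) * np.maximum(np.abs(col) - tau, 0.)
    return infls_soft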
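If looping over columns with tqdm is too slow for a large score matrix, the same per-column threshold can be computed for all columns at once. The sketch below is a hypothetical loop-free variant (the name `soft_threshold_columns` is made up here, not part of trak). It picks tau per column as the (nnz+1)-th largest magnitude via `np.partition`, then applies the standard soft-threshold `sign(x) * max(|x| - tau, 0)`, so entries tied with tau also map to 0.

```python
import numpy as np


def soft_threshold_columns(scores: np.ndarray, nnz: int) -> np.ndarray:
    """Hypothetical vectorized soft-thresholding: each column keeps at most
    nnz non-zero entries, each shrunk towards zero by that column's tau."""
    n = scores.shape[0]
    if not 0 < nnz < n:
        raise ValueError("nnz must satisfy 0 < nnz < scores.shape[0]")
    mags = np.abs(scores)
    kth = n - nnz - 1  # ascending position of the (nnz+1)-th largest magnitude
    # tau has shape (n_columns,) and broadcasts against mags of shape (n, n_columns)
    tau = np.partition(mags, kth, axis=0)[kth]
    return np.sign(scores) * np.maximum(mags - tau, 0.0)
```

For example, with `scores = np.array([[3.0, -1.0], [-2.0, 4.0], [0.5, -0.25]])` and `nnz=1`, the per-column thresholds are 2 and 1, so only the entries 3 and 4 survive, shrunk to 1 and 3.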
https://github.com/MadryLab/trak/blob/b37289c75b4c9210084b2579ccc8f1b7438f1d20/trak/traker.py#L353-L355
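Since nnz is left to the caller, one way to pick it outside TRAKer is a grid search over candidate values against some validation metric. The sketch below is a minimal, hypothetical helper (`choose_nnz` is not a trak function). `val_metric` is a user-supplied callback, for example agreement with held-out ground-truth model outputs; this thread does not specify the metric. A full K-fold scheme would average such a metric over folds before comparing candidates.

```python
import numpy as np
from typing import Callable, Iterable


def choose_nnz(
    scores: np.ndarray,
    candidates: Iterable[int],
    threshold_fn: Callable[[np.ndarray, int], np.ndarray],
    val_metric: Callable[[np.ndarray], float],
) -> int:
    """Return the candidate nnz whose thresholded scores maximize val_metric
    (higher is better); the first candidate wins on ties."""
    best_nnz, best_val = None, -np.inf
    for nnz in candidates:
        val = val_metric(threshold_fn(scores, nnz))
        if val > best_val:
            best_nnz, best_val = nnz, val
    if best_nnz is None:
        raise ValueError("candidates must be non-empty")
    return best_nnz
```

For instance, `choose_nnz(scores, [10, 50, 100], soft_threshold_matrix, my_metric)` would return whichever of the three values scores best under the (hypothetical) `my_metric`.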