vmazashvili / Neural-Networks

MaskTune project reimplementation

modify masking method #2

Closed · vmazashvili closed this issue 1 month ago

vmazashvili commented 1 month ago

Modify the masking method so that it follows the paper more precisely. The relevant passage from the paper is quoted below (rough implementation sketches follow after it):

`A key ingredient of our approach is a masking function $\mathcal{G}$ that is applied offline (i.e., after full training). The goal here is to construct a new masked dataset by concealing the most discriminative features in the input discovered by a model after full training. This should encourage the model to investigate new features with the masked training set during finetuning. As for $\mathcal{G}$, we adopt the xGradCAM~\citep{selvaraju2017grad}, which was originally designed for a visual explanation of deep models by creating rough localization maps based on the gradient of the model loss w.r.t. the output of a desired model layer. Given an input image of size $H \times W \times C$, xGradCAM outputs a localization map $\mathcal{A}$ of size $H \times W \times 1$, which shows the contribution of each pixel of the input image in predicting the most probable class, i.e., it calculates the loss by choosing the class with highest logit value (not the true label) as the target class. After acquiring the localization map, for each sample $(x_i, y_i)$, where $x_i \in X$ and $y_i \in Y$, we mask the locations with the most contribution as:

\begin{equation} \label{eqn:masking} \hat{x}_i = \mathcal{T}(\mathcal{A}_{x_i}; \tau) \odot x_i; \ \ \ \ \mathcal{A}_{x_i} = \mathcal{G}(m_\theta(x_i), y_i) \end{equation}

where $\mathcal{T}$ refers to a thresholding function by the threshold factor $\tau$ (i.e., $\mathcal{T} = \mathbbm{1}\{\mathcal{A}_{x_i} \le \tau\}$), and $\odot$ denotes element-wise multiplication. As the resolution of $\mathcal{A}$ is typically coarser than that of the input data, $\mathcal{T}(\mathcal{A}_{x_i}; \tau)$ is up-sampled to the size of the input.

Procedurally, we first learn model $m_{\theta}^{\text{initial}}$ using original unmasked training data $\mathcal{D}^{\text{initial}}$. Then we use $m_{\theta}^{\text{initial}}$, $\mathcal{G}$ and $\mathcal{T}$ to create the masked set $\mathcal{D}^{\text{masked}}$. Finally, the fully trained predictor $m_{\theta}^{\text{initial}}$ is tuned using $\mathcal{D}^{\text{masked}}$ to obtain $m_{\theta}^{\text{final}}$.

As for the masking step, any explainability approach can be applied (note that some may have more computational complexity, such as ScoreCAM~\citep{wang2020score}). We use xGradCAM~\citep{selvaraju2017grad} as it is fast and produces relatively denser heat-maps than other methods~\citep{srinivas2019full, selvaraju2017grad, wang2020score}.`
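
A minimal PyTorch sketch of the masking equation above, assuming the xGradCAM map for a batch is already available as a tensor (e.g. from a separate xGradCAM implementation); `mask_inputs`, its arguments, and the assumed value range of the map are illustrative, not the repo's actual API:

```python
import torch
import torch.nn.functional as F


def mask_inputs(x: torch.Tensor, cam: torch.Tensor, tau: float) -> torch.Tensor:
    """Mask the most discriminative regions of a batch of images.

    x:   input batch of shape (N, C, H, W)
    cam: xGradCAM localization maps of shape (N, 1, h, w), assumed normalized to [0, 1]
    tau: threshold factor
    """
    # T(A; tau) = 1[A <= tau]: keep low-saliency pixels, hide the most contributing ones
    keep = (cam <= tau).float()
    # The CAM resolution is coarser than the input, so up-sample the binary mask to H x W
    keep = F.interpolate(keep, size=x.shape[-2:], mode="nearest")
    # x_hat = T(A; tau) ⊙ x  (element-wise multiplication)
    return x * keep
```

If an off-the-shelf library is used to compute the maps, check whether it already up-samples them to input resolution; in that case the interpolation step can be dropped.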
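And a rough outline of the three-stage procedure described above (full training, offline mask generation, finetuning on the masked set); `train`, `finetune`, and `compute_xgradcam` are hypothetical helpers standing in for whatever the project actually uses:

```python
# Stage 1: fully train the initial model on the unmasked data D^initial.
model_initial = train(model, loader_initial)            # hypothetical training loop

# Stage 2: build the masked dataset D^masked offline with the fully trained model.
model_initial.eval()
masked_batches = []
for x, y in loader_initial:
    # G(m_theta(x_i), y_i): saliency w.r.t. the highest-logit class, not the true label
    cam = compute_xgradcam(model_initial, x)             # hypothetical helper, returns (N, 1, h, w)
    masked_batches.append((mask_inputs(x, cam, tau=0.5), y))  # tau is a tunable factor

# Stage 3: tune the fully trained predictor on D^masked to obtain m_theta^final.
model_final = finetune(model_initial, masked_batches)    # hypothetical finetuning loop
```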