Closed Lucas-rbnt closed 4 days ago
Hi @Lucas-rbnt thanks for the effort on this followup PR. @atbenmurray could you please re-review the content here?
@Lucas-rbnt @atbenmurray I shall do so
I think this is fine now, though the comments should be looked at and the conflict resolved. Then we can trigger the blossom tests. Thanks!
/build
This follows a previous PR (#7598).
In the previous PR, the official implementation was under a non-compatible license. This is a clean-sheet implementation I developed. The code is fairly straightforward, involving a transformer, encoder, and decoder. The primary changes are in how masks are selected and how patches are organized as they pass through the model.
In the official masked autoencoder implementation, noise is first generated and then sorted twice using `torch.argsort`. This rearranges the tokens and identifies which ones are retained, ultimately selecting only a subset of the shuffled indices. In our implementation, we use `torch.multinomial` to generate mask indices, followed by simple boolean indexing to manage the sub-selection of patches for encoding and the reordering with mask tokens in the decoder.
Let me know if you need a detailed, line-by-line explanation of the new code, including how it works and how it differs from the previous version.
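To make the mask-selection idea above concrete, here is a minimal sketch of sampling with `torch.multinomial` and then using boolean indexing for both the encoder sub-selection and the decoder reordering. It is not the code from this PR: the function names (`sample_mask`, `select_visible`, `restore_with_mask_tokens`), the argument names, and the convention that `True` marks a masked token are all assumptions made for illustration.

```python
import torch


def sample_mask(batch_size: int, num_tokens: int, masking_ratio: float) -> torch.Tensor:
    """Boolean mask of shape (batch_size, num_tokens); True marks a masked token (assumed convention)."""
    num_masked = int(masking_ratio * num_tokens)
    if not 0 < num_masked < num_tokens:
        raise ValueError("masking_ratio must leave at least one masked and one visible token")
    # Uniform weights: every token is equally likely to be drawn.
    weights = torch.ones(batch_size, num_tokens)
    # Draw num_masked distinct token indices for each sample in the batch.
    masked_idx = torch.multinomial(weights, num_masked, replacement=False)
    mask = torch.zeros(batch_size, num_tokens, dtype=torch.bool)
    rows = torch.arange(batch_size).unsqueeze(1)
    mask[rows, masked_idx] = True
    return mask


def select_visible(tokens: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Keep only the unmasked tokens; every row keeps the same count, so the reshape is valid."""
    batch_size, _, dim = tokens.shape
    return tokens[~mask].reshape(batch_size, -1, dim)


def restore_with_mask_tokens(
    encoded: torch.Tensor, mask: torch.Tensor, mask_token: torch.Tensor
) -> torch.Tensor:
    """Rebuild the full sequence: mask_token at masked positions, encoded tokens elsewhere."""
    batch_size, num_tokens = mask.shape
    dim = encoded.shape[-1]
    # Start from a sequence made only of mask tokens, then write the visible tokens back in place.
    full = mask_token.expand(batch_size, num_tokens, dim).clone()
    full[~mask] = encoded.reshape(-1, dim)
    return full
```

Because boolean indexing walks the tensor in row-major order, the visible tokens keep their original relative order, so no explicit "restore" permutation is needed as in the double-`argsort` approach.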
Description
Implementation of the Masked Autoencoder as described in the paper: Masked Autoencoders Are Scalable Vision Learners by Kaiming He et al.
Its effectiveness has already been demonstrated in the literature for medical tasks in the paper Self Pre-training with Masked Autoencoders for Medical Image Classification and Segmentation. The PR contains the architecture and associated unit tests.
Note: The output includes the prediction, a tensor of size ($BS$, $N_{tokens}$, $D$), and the associated mask of size ($BS$, $N_{tokens}$). The mask is used to apply the loss only to masked patches, but I'm not sure it's the “best” output format. What do you think?
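As a rough illustration of how a caller might consume this (prediction, mask) pair, the sketch below averages a mean-squared error over masked patches only. The function name `masked_mse_loss`, the choice of MSE, the `target` tensor of patchified inputs, and the convention that nonzero mask entries mark masked patches are assumptions, not something this PR specifies.

```python
import torch


def masked_mse_loss(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """MSE averaged over masked patches only.

    pred, target: (BS, N_tokens, D); mask: (BS, N_tokens), nonzero = masked (assumed convention).
    """
    mask = mask.to(pred.dtype)
    # Per-patch error, averaged over the feature dimension D.
    per_patch = ((pred - target) ** 2).mean(dim=-1)
    # Zero out visible patches and normalise by the number of masked ones.
    return (per_patch * mask).sum() / mask.sum().clamp(min=1)
```

Returning the mask alongside the prediction keeps the loss outside the model, so a different reconstruction loss can be swapped in without touching the architecture.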
Types of changes
`./runtests.sh -f -u --net --coverage`
`./runtests.sh --quick --unittests --disttests`
`make html` command in the `docs/` folder.