lucidrains / enformer-pytorch

Implementation of Enformer, Deepmind's attention network for predicting gene expression, in Pytorch
MIT License

Initializing AttentionPool weights with 2 * Identity matrix. #21

Closed dohlee closed 1 year ago

dohlee commented 1 year ago

Hi, thanks for this great implementation. I'm learning a lot from it :)

I noticed that the commit https://github.com/lucidrains/enformer-pytorch/commit/0dc46e41de96bd739edba2cfaaa5e123990e9bc7 changed the internal AttentionPool weight so that it is initialized with randomly sampled values rather than with 2 * identity matrix, which is what the Enformer paper specifies. (If I'm missing something, please let me know!)

Admittedly, this will not affect the performance of Enformer loaded with pretrained parameters, but according to the paper it may lead to slightly worse performance when training from scratch.

Perhaps some simple manual weight initialization like

self.to_attn_logits = nn.Conv2d(dim, dim, 1, bias = False)
# the conv weight has shape (dim, dim, 1, 1); set it to 2 * identity
self.to_attn_logits.weight.data.zero_()
self.to_attn_logits.weight.data.squeeze().fill_diagonal_(2)

will do.
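For anyone who wants to check the idea in isolation, here is a minimal standalone sketch. It assumes a 1x1 `nn.Conv2d` without bias, as in the snippet above; the `dim` value and the input shape are hypothetical and chosen only for illustration, and this is not necessarily how the repository ended up implementing it. With the weight set to 2 * identity, the layer should simply return twice its input as the attention logits.

```python
import torch
from torch import nn

dim = 8  # hypothetical channel count, for illustration only

# 1x1 convolution that produces the attention logits (no bias)
to_attn_logits = nn.Conv2d(dim, dim, 1, bias=False)

with torch.no_grad():
    # Conv2d weight has shape (out_channels, in_channels, 1, 1),
    # so reshape 2 * identity to (dim, dim, 1, 1) before copying it in
    to_attn_logits.weight.copy_(2 * torch.eye(dim).view(dim, dim, 1, 1))

# Hypothetical input laid out as (batch, channels, length, pool_size)
x = torch.randn(1, dim, 4, 2)
logits = to_attn_logits(x)

# With a 2 * identity weight, each logit is twice the matching input value
print(torch.allclose(logits, 2 * x))
```

Copying a prebuilt matrix into the weight under `torch.no_grad()` avoids relying on `squeeze()` returning a view, but the zero-then-fill-diagonal approach should give the same result.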

If you think that's okay, please let me know and I'll open a PR right away.

Thanks again for this great repo!

Best, Dohoon

lucidrains commented 1 year ago

@dohlee Hi Dohoon! Thank you for catching this and glad to see another researcher applying attention to genomics! I've made the change here