lucidrains / enformer-pytorch

Implementation of Enformer, Deepmind's attention network for predicting gene expression, in Pytorch
MIT License

Hard-coded input sequence length to the transformer blocks when using use_tf_gamma = True #32

Closed zhhhhahahaha closed 11 months ago

zhhhhahahaha commented 1 year ago

Hi! Thanks for your amazing code. I am trying to use the pre-trained model, but I found that when I set use_tf_gamma = True, I can only use the precomputed gamma positions for an input sequence of length 1536. Will you fix that later? Also, the sanity check fails. After running this:

```
python test_pretrained.py
Traceback (most recent call last):
  File "/home/ubuntu/enformer-pytorch/test_pretrained.py", line 11, in <module>
    corr_coef = enformer(
  File "/opt/conda/envs/seq/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/enformer-pytorch/enformer_pytorch/modeling_enformer.py", line 450, in forward
    x = trunk_fn(x)
  File "/opt/conda/envs/seq/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/seq/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/opt/conda/envs/seq/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/seq/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/opt/conda/envs/seq/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/seq/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/opt/conda/envs/seq/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/enformer-pytorch/enformer_pytorch/modeling_enformer.py", line 144, in forward
    return self.fn(x, **kwargs) + x
  File "/opt/conda/envs/seq/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/seq/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/opt/conda/envs/seq/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/enformer-pytorch/enformer_pytorch/modeling_enformer.py", line 269, in forward
    positions = get_positional_embed(n, self.num_rel_pos_features, device, use_tf_gamma = self.use_tf_gamma)
  File "/home/ubuntu/enformer-pytorch/enformer_pytorch/modeling_enformer.py", line 123, in get_positional_embed
    embeddings = torch.cat(embeddings, dim = -1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 2047 but got size 3071 for tensor number 2 in the list.
```

The program raises the error above because the input sequence length at the transformer blocks is 1024 for the test sample.
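
For context, here is a minimal sketch of the gamma positional basis (following the parameterization in the reference TensorFlow implementation, not this repository's exact code), which shows why features precomputed for length 1536 cannot be reused at length 1024: the component means and stddev are derived from the sequence length, and the number of relative positions is 2n - 1.

```python
# hedged sketch, not the library's exact code: gamma positional features whose
# parameters depend on the sequence length
import torch

def gamma_pdf(x, concentration, rate):
    # elementwise Gamma density, evaluated in log space for stability
    log_unnormalized = torch.xlogy(concentration - 1., x) - rate * x
    log_normalization = torch.lgamma(concentration) - concentration * torch.log(rate)
    return torch.exp(log_unnormalized - log_normalization)

def positional_features_gamma(positions, num_features, seq_len, eps = 1e-8):
    # the length coupling: means and stddev are functions of seq_len
    stddev = seq_len / (2 * num_features)
    start_mean = seq_len / num_features
    mean = torch.linspace(start_mean, seq_len, num_features)
    concentration = (mean / stddev) ** 2
    rate = mean / stddev ** 2
    probabilities = gamma_pdf(positions.abs().float()[..., None], concentration, rate) + eps
    return probabilities / probabilities.max()

positions = torch.arange(-1023, 1024)               # 2 * 1024 - 1 = 2047 relative positions
print(positional_features_gamma(positions, 32, 1024).shape)
# torch.Size([2047, 32]); the tf gammas precomputed for length 1536 cover
# 2 * 1536 - 1 = 3071 positions, hence the size mismatch in torch.cat above
```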

lucidrains commented 1 year ago

@zhhhhahahaha ah bummer

yea i can fix it, though it will take up a morning

can you force it off for now? (by setting use_tf_gamma = False)
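
(For reference, a hedged sketch of this workaround, assuming from_pretrained forwards use_tf_gamma as a config override; the exact keyword handling may differ in the installed version.)

```python
# hedged sketch of the workaround: disable the tf-gamma hack so positional
# features are recomputed for whatever length reaches the transformer blocks.
# assumption: from_pretrained accepts use_tf_gamma as an override.
import torch
from enformer_pytorch import from_pretrained

enformer = from_pretrained('EleutherAI/enformer-official-rough', use_tf_gamma = False)
enformer.eval()

seq = torch.randint(0, 5, (1, 196_608))   # ACGTN indices, as in the README
with torch.no_grad():
    output = enformer(seq)

print(output['human'].shape)              # (1, 896, 5313)
```

With the hack off the outputs can drift from the TensorFlow model by small rounding differences, but the length restriction goes away.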

lucidrains commented 1 year ago

@zhhhhahahaha want to see if 0.8.4 fixes it on your machine?

zhhhhahahaha commented 1 year ago

Thanks for your work! But I think that if we use inputs with different sequence lengths, we need to recompute the gamma positional encoding, because $\mu_i$ and $\sigma$ in the gamma probability density function proposed in the Enformer paper change with the sequence length. However, it is not very meaningful to keep the same Enformer parameters while choosing a different input sequence length. All in all, I think using the pre-trained Enformer parameters with the same input sequence length is enough; I will figure it out myself if I need to use different input sequence lengths (maybe by retraining the transformer blocks).
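
Concretely, the basis being discussed is a set of gamma densities whose parameters are tied to the sequence length $L$ (a sketch following the reference implementation's parameterization):

$$
f_i(|x|) = \frac{\beta_i^{\alpha_i}}{\Gamma(\alpha_i)}\,|x|^{\alpha_i - 1} e^{-\beta_i |x|},
\qquad
\alpha_i = \left(\frac{\mu_i}{\sigma}\right)^{2},
\qquad
\beta_i = \frac{\mu_i}{\sigma^{2}},
$$

with $\mu_i$ spaced linearly from $L/F$ to $L$ for $F$ features and $\sigma = L/(2F)$, so changing $L$ changes every $\mu_i$ and $\sigma$, which is the point made above.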

lucidrains commented 1 year ago

@zhhhhahahaha ah yea, we could expand the precomputed tf gammas to all sequence lengths from 1 to 1536, then index them out
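
A hedged sketch of that idea, reusing the hypothetical positional_features_gamma helper sketched earlier (to keep exact agreement with Deepmind's model, the actual values would need to be generated from the TensorFlow reference):

```python
# hedged sketch: precompute gamma features once per sequence length up to 1536
# and index by length at forward time. positional_features_gamma is the
# hypothetical helper from the sketch above, not the repository's API.
import torch

MAX_SEQ_LEN = 1536

def build_gamma_table(num_features = 32):
    table = {}
    for seq_len in range(1, MAX_SEQ_LEN + 1):
        positions = torch.arange(-seq_len + 1, seq_len)
        table[seq_len] = positional_features_gamma(positions, num_features, seq_len)
    return table

gamma_table = build_gamma_table()
print(gamma_table[1024].shape)   # (2 * 1024 - 1, 32) = (2047, 32)
```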

lucidrains commented 1 year ago

i swear this is the last time i ever want to deal with tensorflow

lucidrains commented 1 year ago

@zhhhhahahaha if you have tensorflow environment installed and could get me that matrix, i can get this fixed in a jiffy

zhhhhahahaha commented 1 year ago

I haven't installed a tensorflow environment, and I've decided to retrain the model or just ignore these small rounding errors, thanks!

lucidrains commented 1 year ago

@johahi what do you think?

lucidrains commented 1 year ago

> I haven't installed a tensorflow environment, and I've decided to retrain the model or just ignore these small rounding errors, thanks!

yea neither do i

lucidrains commented 1 year ago

the other option would be to check if https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.special.xlogy.html is equivalent to the tensorflow xlogy

then we use jax2torch
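
A hedged sketch of that check (this is the comparison johahi reports on below):

```python
# quick check: do jax, torch and tensorflow agree on xlogy in float32?
import numpy as np
import torch
import tensorflow as tf
import jax.numpy as jnp
from jax.scipy.special import xlogy as jax_xlogy

x = np.random.rand(4096).astype(np.float32)
y = np.random.rand(4096).astype(np.float32)

out_torch = torch.xlogy(torch.from_numpy(x), torch.from_numpy(y)).numpy()
out_jax = np.asarray(jax_xlogy(jnp.asarray(x), jnp.asarray(y)))
out_tf = tf.math.xlogy(tf.constant(x), tf.constant(y)).numpy()

print('jax vs torch:', np.allclose(out_jax, out_torch))
print('jax vs tf:   ', np.allclose(out_jax, out_tf))
print('torch vs tf: ', np.allclose(out_torch, out_tf))
```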

johahi commented 1 year ago

@lucidrains i'll try jax2torch, will let you know if that works! i don't know if the model performs well when it is used with cropped sequences, so just the tf-gammas for the original length of 1536 were fine for my use case...

johahi commented 1 year ago

@lucidrains from quick tests it seems like jax and torch have the same xlogy implementation (result after xlogy is allclose between them, but not between jax and tf or pt and tf), so this won't help, unfortunately :disappointed:

lucidrains commented 1 year ago

@johahi oh bummer, we'll just let the tf gamma hack work only for the 1536 sequence length then
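
In practice the resolution amounts to a guard along these lines (a sketch with illustrative names, not the repository's actual code):

```python
# sketch of the resulting constraint: the tf-gamma hack only covers the
# original transformer input length. names here are illustrative.
TF_GAMMA_SEQ_LEN = 1536

def check_use_tf_gamma(seq_len: int, use_tf_gamma: bool) -> None:
    # fail early with a clear message instead of inside torch.cat as in the traceback above
    if use_tf_gamma and seq_len != TF_GAMMA_SEQ_LEN:
        raise ValueError(
            f'use_tf_gamma = True only supports a transformer input length of {TF_GAMMA_SEQ_LEN}, '
            f'got {seq_len}; set use_tf_gamma = False for other lengths'
        )
```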

lucidrains commented 1 year ago

@johahi thanks for checking!