Closed: zhhhhahahaha closed this issue 11 months ago

Hi! Thanks for your amazing code. I am trying to use the pre-trained model, but I found that when I set use_tf_gamma = True, I can only use the precomputed gamma positions for an input sequence of length 1536. Will you fix that later? Also, the sanity check fails: after running it, the program raises the problem described above, because the input sequence length for the transformer block is 1024 for the test sample.
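The check I ran is along these lines (a sketch of the repo's test snippet, reconstructed from memory; the data path here is a placeholder and the exact kwargs may differ):

```python
import torch
from enformer_pytorch import from_pretrained

# load the ported official weights with the precomputed tf gamma positions
enformer = from_pretrained('EleutherAI/enformer-official-rough', use_tf_gamma = True)
enformer.eval()

# placeholder path - the bundled test sample yields 1024 positions in the
# transformer block instead of the 1536 the gammas were precomputed for
data = torch.load('./data/test-sample.pt')
seq, target = data['sequence'], data['target']

with torch.no_grad():
    corr_coef = enformer(seq, target = target, return_corr_coef = True)

print(corr_coef)
```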
@zhhhhahahaha ah bummer
yea i can fix it, though it will take up a morning
can you force it off for now? (by setting use_tf_gamma = False)
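i.e. roughly like this (a sketch, assuming the flag can be overridden through from_pretrained):

```python
from enformer_pytorch import from_pretrained

# falls back to the pure-pytorch gamma positional features, which work for
# any sequence length at the cost of tiny numerical differences from TF
enformer = from_pretrained('EleutherAI/enformer-official-rough', use_tf_gamma = False)
```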
@zhhhhahahaha want to see if 0.8.4 fixes it on your machine?
Thanks for your work! But I think if we use inputs with different sequence lengths, we need to recompute the gamma position encoding, because we have different $\mu_i$ and $\sigma$ for the gamma probability distribution functions proposed in the Enformer paper. However, it is not so meaningful to keep the same Enformer parameters but choose a different input sequence length. All in all, I think using the pre-trained Enformer's parameters with the same input sequence length is enough; I will figure it out myself if I need different input sequence lengths (maybe retrain the transformer block).
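For context, the gamma features derive their parameters directly from the sequence length, roughly like this (a sketch paraphrasing the idea, not the exact TF code):

```python
import numpy as np
from scipy.stats import gamma

# sketch of how the gamma positional basis depends on seq_length
# (paraphrased from Enformer's positional_features_gamma; details may differ)
def gamma_positional_features(seq_length, feature_size):
    positions = np.arange(-seq_length + 1, seq_length)  # relative positions
    stddev = seq_length / (2 * feature_size)            # the shared sigma
    mean = np.linspace(seq_length / feature_size, seq_length, num = feature_size)  # the mu_i
    concentration = (mean / stddev) ** 2                # gamma shape parameters
    rate = mean / stddev ** 2                           # gamma rate parameters
    # evaluate each gamma pdf at |relative position| -> (2L - 1, feature_size)
    probs = gamma.pdf(np.abs(positions)[:, None], concentration, scale = 1.0 / rate)
    return probs / probs.max(axis = 0, keepdims = True) # normalize per basis function

# mu_i and sigma change with seq_length, so features precomputed for 1536
# positions are not valid for a 1024-position input
feats_1536 = gamma_positional_features(1536, 32)
feats_1024 = gamma_positional_features(1024, 32)
```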
@zhhhhahahaha ah yea, we could expand the precomputed tf gammas to cover all sequence lengths from 1 to 1536, then index into it
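i.e. something like this (a sketch; the file name and storage layout are hypothetical):

```python
import torch

# hypothetical layout: one tf-computed gamma matrix per sequence length,
# exported from tensorflow once and shipped with the package
TF_GAMMAS = torch.load('precomputed_tf_gammas_1_to_1536.pt')  # hypothetical file

def get_tf_gamma(seq_len):
    # index out the matrix matching the transformer block's input length
    assert 1 <= seq_len <= 1536, 'tf gammas only precomputed up to length 1536'
    return TF_GAMMAS[seq_len]
```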
i swear this is the last time i ever want to deal with tensorflow
@zhhhhahahaha if you have a tensorflow environment installed and could get me that matrix, i can get this fixed in a jiffy
I haven't installed the tensorflow environment, and I've decided to retrain the model or just ignore these small rounding errors, thanks!
yea neither do i
the other option would be to check if https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.special.xlogy.html is equivalent to the tensorflow xlogy
then we use jax2torch

@johahi what do you think?
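the bridge would be roughly this (a sketch following jax2torch's documented usage pattern):

```python
import jax
import torch
from jax.scipy.special import xlogy
from jax2torch import jax2torch

# wrap jax's xlogy so it can be called on torch tensors inside the model
torch_xlogy = jax2torch(jax.jit(xlogy))

x, y = torch.rand(16), torch.rand(16)
out = torch_xlogy(x, y)  # torch tensor in, torch tensor out
```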
@lucidrains i'll try jax2torch, will let you know if that works! i don't know if the model performs well when it is used with cropped sequences, so just the tf-gammas for the original length of 1536 were fine for my use case...
@lucidrains from quick tests it seems like jax and torch have the same xlogy implementation (result after xlogy is allclose between them, but not between jax and tf or pt and tf), so this won't help, unfortunately :disappointed:
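a quick version of that comparison looks like this (a sketch, not the exact test; the tolerance is illustrative):

```python
import numpy as np
import torch
import jax.numpy as jnp
from jax.scipy.special import xlogy as jax_xlogy

x = np.random.rand(1024).astype(np.float32)
y = np.random.rand(1024).astype(np.float32)

out_pt = torch.xlogy(torch.from_numpy(x), torch.from_numpy(y)).numpy()
out_jax = np.asarray(jax_xlogy(jnp.asarray(x), jnp.asarray(y)))

# pytorch and jax agree with each other, but neither matches tf.math.xlogy exactly
print(np.allclose(out_pt, out_jax, atol = 1e-7))
```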
@johahi oh bummer, we'll just let the tf gamma hack work only for 1536 sequence length then
@johahi thanks for checking!