@alexdemartos oh that's interesting, you are the second researcher within the last week that's asked about positions for cross attention
could you expand on what you are doing? i assume this is audio related, given we last interacted when you were using soundstream over at audiolm-pytorch!
is the source and target sequence aligned token by token? also, what type of relative positions are you using?
@alexdemartos you can rule out caching issues by disabling it with this flag
Hi @lucidrains, what an excellent memory you have! :)
Thanks for your quick response. Let me be more precise:
I am using a Transformer Decoder to auto-regressively predict real-valued phoneme-level duration, pitch and energy. I found your ContinuousAutoregressiveWrapper pretty handy for the task, with just minor mods.

The context vector C (encoded phonemes) and the targets (dur/pitch/energy) are of the same length L. First I tried just setting rel_pos = True, but as I got noisy inference predictions, I realized no positional information is added to context, so I thought this might be the issue and tried adding rel_pos to the cross-attention block. I thought this would be a good option given C and the targets are time-synchronous.
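For reference, my setup looks roughly like this (just a sketch with made-up dimensions; the actual code has the minor mods mentioned above, and cross_attend / rel_pos_bias are the standard Decoder kwargs):

```python
import torch
from x_transformers import ContinuousTransformerWrapper, ContinuousAutoregressiveWrapper, Decoder

# rough sketch of the setup (dimensions made up)
model = ContinuousTransformerWrapper(
    dim_in = 3,                  # per-phoneme duration / pitch / energy
    dim_out = 3,
    max_seq_len = 512,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        cross_attend = True,     # attend to the encoded phonemes C
        rel_pos_bias = True      # relative positions (self-attention only by default)
    )
)

wrapper = ContinuousAutoregressiveWrapper(model)

targets = torch.randn(1, 128, 3)      # (batch, L, 3) dur / pitch / energy
context = torch.randn(1, 128, 512)    # encoded phonemes C, same length L

loss = wrapper(targets, context = context)   # teacher-forced MSE loss
loss.backward()
```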
I am, however, experiencing the same behaviour as before adding rel_pos to the cross-attention block: training works well, but the inference process seems broken.
I will try disabling caching as you suggested. Thanks for your time!
PS: training vs. inference pitch contour plots attached.
@alexdemartos that's an interesting use case for ContinuousAutoregressiveWrapper
! yes, do let me know if disabling cache fixes it or not, and i'll throw some brain cycles into this issue. no guarantees though, as your use case is a bit off the beaten path
@alexdemartos how did turning off the caching go? i thought of a way to generalize relative positions within the attention blocks, so just let me know
Hi @lucidrains. Thanks for chasing this. Actually, I didn't manage to turn off the flag you mentioned, as it doesn't seem to be available for the ContinuousAutoregressiveWrapper.
I still didn't manage to get inference working. I tried implementing mask_prob with a large dropout (0.5) to prevent the exposure bias from teacher forcing, but this didn't seem to help significantly.
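What I mean by that is roughly the following (just an illustrative sketch of my own; mask_teacher_forcing is a hypothetical helper, not an x-transformers feature):

```python
import torch

def mask_teacher_forcing(x, mask_prob = 0.5):
    # x: ground-truth input frames fed to the decoder, shape (batch, seq_len, dim).
    # randomly zero a fraction of the frames so the model cannot rely purely on
    # the previous ground-truth value (to mitigate exposure bias from teacher forcing)
    keep = torch.rand(x.shape[:2], device = x.device) >= mask_prob
    return x * keep.unsqueeze(-1)
```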
@alexdemartos oh that's right, continuous doesn't have kv cache just yet
ok, so the issue must be unrelated then
@alexdemartos what happens if you remove the relative positions altogether? perhaps give the source and target a weight-tied absolute positional embedding?
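something along these lines, just as a sketch (names and sizes are made up, and this would live outside of x-transformers, applied to the embeddings you feed in):

```python
import torch
from torch import nn

max_seq_len, dim = 1024, 512

# one absolute positional embedding, weight tied between source and target
pos_emb = nn.Embedding(max_seq_len, dim)

def add_shared_positions(context, target):
    # context, target: (batch, seq_len, dim), time-aligned and of equal length
    positions = torch.arange(target.shape[1], device = target.device)
    p = pos_emb(positions)               # (seq_len, dim), broadcast over batch
    return context + p, target + p
```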
Testing this next :)
Update: unfortunately, no luck disabling rel_pos_bias either. The results look slightly different, but still garbage.
Training: https://ibb.co/h74wysQ Inference: https://ibb.co/hKjDLcs
@alexdemartos oh, so it is unrelated to positioning then
Hi. It's been a long time, but I finally found the root of the issue. It is not related to the current library implementation, but to my own implementation of rel_pos_bias in the cross-attention layer of the Transformer Decoder. Anyway, I just wanted to post it here for completeness, in case anyone is interested in a similar implementation.

I was passing self.rel_pos to both the self-attention and cross-attention layers. However, the RelativePositionBias of the self-attention layer gets causal = True from the decoder parameters, while causal should be False for the cross-attention layer. I solved the issue by creating a separate self.rel_pos_cross = RelativePositionBias(... causal = False) and passing that one to the 'c' layer.
def __init__(
    self,
    ...
):
    ...
    # separate relative position bias for cross-attention: causal must be False here,
    # unlike the self-attention bias, which inherits causal = True from the decoder
    self.rel_pos_cross = RelativePositionBias(scale = dim_head ** 0.5, causal = False, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)

def forward(
    ...
):
    ...
    elif layer_type == 'c':
        # pass the non-causal bias to the cross-attention ('c') block
        out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, rel_pos = self.rel_pos_cross, cache = next(iter_attn_cache, None), return_intermediates = True)
    ...
@alexdemartos nice! hope you trained a cool model in the end 😄
Hi,
I am planning to implement relative positional encoding for the 'c' (cross-attention) block. In my case, the target and context sequences are of the same length and synchronous, so hopefully the relative positional encoding will help the attention focus on the corresponding part of the context.
I tried passing rel_pos (https://github.com/lucidrains/x-transformers/blob/b2979195ba496532eb0b7f52616eed178848d8af/x_transformers/x_transformers.py#L1338) to the 'c' block, roughly as in the sketch at the end of this post. While this works well for training, inference results are garbage. Maybe some caching issues? Any help is very much appreciated.
Thanks in advance!
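In case it helps, this is roughly the modification I mean inside AttentionLayers.forward (a reconstructed sketch, not a verbatim diff; self.rel_pos here is the decoder's own RelativePositionBias):

```python
# inside AttentionLayers.forward -- reconstructed sketch of my modification,
# reusing the decoder's relative position bias for the cross-attention block
elif layer_type == 'c':
    out, inter = block(
        x,
        context = context,
        mask = mask,
        context_mask = context_mask,
        prev_attn = prev_cross_attn,
        rel_pos = self.rel_pos,              # <-- added
        cache = next(iter_attn_cache, None),
        return_intermediates = True
    )
```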