Closed lucidrains closed 3 years ago

cfoster0: Looks nice! One question on the masking: how should one interpret the mask shape? I was expecting `batch x sequence x sequence`, whereas this looks like `batch x sequence`.
lucidrains: @cfoster0 it actually got rearranged from `sequence` to `sequence i x sequence j x heads` here, and then the vmap takes care of the batch (i think)
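A minimal sketch of the shape flow being described, assuming JAX and einops; the function name and the padding-mask semantics are illustrative, not the repo's actual code:

```python
import jax
import jax.numpy as jnp
from einops import rearrange

def attend_single(q, k, mask):
    # q, k: (heads, seq, dim_head) for ONE example; mask: (seq,) boolean padding mask
    sim = jnp.einsum('h i d, h j d -> h i j', q, k)
    sim = rearrange(sim, 'h i j -> i j h')            # sequence i x sequence j x heads
    sim = jnp.where(mask[None, :, None], sim, -1e10)  # mask out padded key positions
    return jax.nn.softmax(sim, axis=1)                # softmax over the key axis

# vmap adds the batch dimension back: inputs become (batch, heads, seq, dim_head)
attend = jax.vmap(attend_single)
```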
cfoster0: Hmm. I don't know what you'd put in the mask if you wanted it to be causal, since it starts out as `sequence`. Ideally I think you'd specify a `sequence x sequence` square mask, and have it expand out the heads dimension.
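A sketch of that suggestion, assuming JAX; the sizes are illustrative:

```python
import jax.numpy as jnp

seq_len, heads = 5, 8
causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))   # sequence x sequence square mask
causal = jnp.broadcast_to(causal[:, :, None],                 # expand out the heads dimension
                          (seq_len, seq_len, heads))
```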
lucidrains: @cfoster0 oh wait, are we doing causal here? i'm just accounting for variable-length sequences (non-causal)
cfoster0: Yeah, it'd be nice to be able to do causal on the text side, for the non-CLS tokens. I don't think we have to worry about variable-length sequences because we can just pad the input data.
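For concreteness, a sketch of the padding approach, assuming JAX; the lengths and names are illustrative:

```python
import jax.numpy as jnp

batch, seq_len = 2, 5
lengths = jnp.array([5, 3])                                # true lengths before padding
padding = jnp.arange(seq_len)[None, :] < lengths[:, None]  # batch x sequence, True = real token
key_mask = padding[:, None, :]                             # broadcasts against (batch, seq_i, seq_j) logits
```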
lucidrains: Ohh, how would the causal side learn? Would it have an autoregressive loss in addition to the contrastive one?
cfoster0: I don't think it needs any special loss. Tokens would attend to tokens to the left of themselves, and since the CLS token can see everything else, it can aggregate whatever it needs for the contrastive task.
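A sketch of the scheme described above, assuming JAX; the helper name and CLS position are hypothetical, not from the repo:

```python
import jax.numpy as jnp

def causal_mask_with_cls(seq_len, cls_index=0):
    # Every token attends only leftward, except the CLS token, which
    # attends to all positions so it can aggregate features for the
    # contrastive loss.
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))  # token i sees tokens <= i
    mask = mask.at[cls_index, :].set(True)                     # CLS row sees everything
    return mask
```

The CLS row does break strict causality, but as the comment above reasons, that should be fine here since there's no autoregressive loss on the other tokens.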
lucidrains: @cfoster0 ok, let's run with that! i added the triangular mask in the latest commit
cfoster0: Awesome! LGTM 💯