Closed: qAp closed this issue 2 years ago.
Right, the shape of x_affine is (bs, sequence_length, input_dim, time_embed_dim). For example, when I run the toy dataset with batch size 16, 32 context points, and a time embedding dim of 12, I get x.shape = (16, 32, 7, 7), self.embed_weight.shape = (7, 12), and x_affine.shape = (16, 32, 7, 12) during the encoder pass. The final output of the Time2Vec layer has shape (batch_size, sequence_length, input_dim * time_embed_dim), which is (16, 32, 84) in this example since 7 * 12 = 84.
The comments in that section are incorrect; they should be fixed in the next PR.
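The shapes in the reply above can be reproduced with a small shape-level sketch. It is written in NumPy rather than PyTorch and is not the repository's code. It assumes that the reported x.shape = (16, 32, 7, 7) is the input after each time step's 7 features have been placed on the diagonal of a 7 x 7 matrix (what torch.diag_embed does). It also assumes the standard Time2Vec activation: the first channel is kept linear and sin is applied to the rest. The function name time2vec_forward is hypothetical.

```python
import numpy as np


def time2vec_forward(x, weight, bias):
    """Shape-level sketch of a Time2Vec-style layer (hypothetical helper).

    x:      (bs, seq_len, input_dim)
    weight: (input_dim, embed_dim)
    bias:   (input_dim, embed_dim)
    """
    bs, seq_len, input_dim = x.shape
    # Assumption: put each step's features on a diagonal, like torch.diag_embed.
    # (bs, seq_len, input_dim) -> (bs, seq_len, input_dim, input_dim)
    x_diag = x[..., :, None] * np.eye(input_dim)
    # (bs, seq_len, input_dim, input_dim) @ (input_dim, embed_dim)
    # -> (bs, seq_len, input_dim, embed_dim); row i equals x[..., i] * weight[i]
    x_affine = x_diag @ weight + bias
    # Assumed Time2Vec activation: channel 0 stays linear, the rest go through sin.
    x_out = np.concatenate(
        [x_affine[..., :1], np.sin(x_affine[..., 1:])], axis=-1
    )
    # Flatten the last two axes -> (bs, seq_len, input_dim * embed_dim)
    return x_diag, x_affine, x_out.reshape(bs, seq_len, -1)


# Same sizes as the toy-dataset example: batch 16, 32 context points,
# 7 input features, time embedding dim 12.
rng = np.random.default_rng(0)
bs, seq_len, input_dim, embed_dim = 16, 32, 7, 12
x = rng.standard_normal((bs, seq_len, input_dim))
weight = rng.standard_normal((input_dim, embed_dim))
bias = rng.standard_normal((input_dim, embed_dim))

x_diag, x_affine, out = time2vec_forward(x, weight, bias)
print(x_diag.shape)    # (16, 32, 7, 7)
print(x_affine.shape)  # (16, 32, 7, 12)
print(out.shape)       # (16, 32, 84)
```

The diagonal trick gives every input feature its own row of embed_dim frequencies. That is why x_affine keeps a separate input_dim axis before the final reshape merges it with the embedding axis.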
Shouldn't the shape of x_affine in this line: https://github.com/QData/spacetimeformer/blob/7e0caf17dd03e5d25e2766c4f7132805779bcc40/spacetimeformer/time2vec.py#L24 be [bs, sample, input_dim, embed_dim]?