lucidrains / h-transformer-1d

Implementation of H-Transformer-1D, Hierarchical Attention for Sequence Learning
MIT License

Sequence length issue when `causal = True` #8

Closed. jaak-s closed this issue 3 years ago

jaak-s commented 3 years ago

I found that, with the causal attention setup, the model still errors when the input is shorter than the maximum sequence length:

import torch
from h_transformer_1d import HTransformer1D

model = HTransformer1D(
    num_tokens = 256,          # number of tokens
    dim = 512,                 # dimension
    depth = 2,                 # depth
    causal = True,             # autoregressive or not
    max_seq_len = 8192,        # maximum sequence length
    heads = 8,                 # heads
    dim_head = 64,             # dimension per head
    block_size = 128           # block size
)

x = torch.randint(0, 256, (1, 8000))   # variable sequence length
mask = torch.ones((1, 8000)).bool()    # variable mask length

# network will automatically pad to power of 2, do hierarchical attention, etc

logits = model(x, mask = mask) # (1, 8000, 256)

This gives the following error:

~/miniconda3/lib/python3.7/site-packages/rotary_embedding_torch/rotary_embedding_torch.py in apply_rotary_emb(freqs, t, start_index)
     43     assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
     44     t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
---> 45     t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
     46     return torch.cat((t_left, t, t_right), dim = -1)
     47 

RuntimeError: The size of tensor a (8192) must match the size of tensor b (8000) at non-singleton dimension 1
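
The mismatch looks like it comes from the rotary embedding frequencies being generated for the full max_seq_len (8192) while the queries/keys only cover the actual input length (8000), so the elementwise multiply in apply_rotary_emb cannot broadcast. A minimal pure-torch sketch of the same kind of shape clash (the tensor names here are only illustrative, not the library's internals):

import torch

max_seq_len, seq_len, rot_dim = 8192, 8000, 64

freqs = torch.randn(max_seq_len, rot_dim)   # frequencies built for the maximum sequence length
t = torch.randn(1, seq_len, rot_dim)        # but the input only has 8000 positions

try:
    out = t * freqs.cos()                   # broadcasting fails at the sequence dimension
except RuntimeError as e:
    print(e)                                # sizes 8192 vs 8000 do not match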

PS Sorry for posting several bug reports in a short time :)

lucidrains commented 3 years ago

@jaak-s no problem, thank you for catching these! :pray: https://github.com/lucidrains/h-transformer-1d/releases/tag/0.0.10
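
With the linked 0.0.10 release installed (e.g. via pip install --upgrade h-transformer-1d, assuming the PyPI package name matches the repo and the fix shipped in that release), the original snippet should run on a sequence shorter than max_seq_len:

# re-running the reproduction above after upgrading
logits = model(x, mask = mask)
print(logits.shape)   # expected: torch.Size([1, 8000, 256])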