pfeatherstone closed this issue 11 months ago
Here is a repro:
```python
import torch
from x_transformers import ContinuousTransformerWrapper, Decoder

lm = ContinuousTransformerWrapper(
    dim_in = 4,
    dim_out = 256 + 3,
    max_seq_len = 0,
    num_memory_tokens = 20,
    attn_layers = Decoder(
        dim = 512,
        depth = 4,
        heads = 4,
        rotary_pos_emb = True,
        attn_flash = True,
        use_scalenorm = True,
        attn_onnxable = True,
        shift_tokens = 1
    )
)

x = torch.randn(2, 1024, 4)
l = torch.randint(100, x.shape[1], size = (x.shape[0],))
m = torch.arange(x.shape[1]).unsqueeze(0) < l.unsqueeze(-1)
x = lm(x, mask = m)
```
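For context, the mask `m` in the repro is a per-sequence length mask: position `j` of sequence `i` is True iff `j` is less than that sequence's sampled length, so padded positions are excluded from attention. A minimal pure-Python sketch of the same logic (the helper name `length_mask` is illustrative, not from the library):

```python
def length_mask(lengths, seq_len):
    """Boolean mask rows: row[i][j] is True iff j < lengths[i].

    Mirrors torch.arange(seq_len).unsqueeze(0) < lengths.unsqueeze(-1)
    from the repro, without the tensor broadcasting.
    """
    return [[j < l for j in range(seq_len)] for l in lengths]

mask = length_mask([3, 5], 6)
# mask[0] -> [True, True, True, False, False, False]
```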
should be fixed