Open yxchng opened 4 weeks ago
`context_length` is used for the causal mask in the mLSTM block. Judging by the code here:
if lower_triangular_matrix is None or S < lower_triangular_matrix.size(-1):
    ltr = torch.tril(torch.ones((S, S), dtype=torch.bool, device=_device))
you can just set `context_length` to your maximum sequence length.
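To make the reasoning concrete, here is a rough plain-Python sketch of the mask selection. It is not the library's code: `tril_mask` and `select_causal_mask` are hypothetical names, nested lists stand in for torch tensors, and everything after the quoted `if` condition (reusing the precomputed mask, and raising when `S` exceeds its size) is an assumption made for illustration only.

```python
def tril_mask(n):
    """Return an n x n lower-triangular boolean mask as nested lists
    (a stand-in for torch.tril(torch.ones((n, n), dtype=torch.bool)))."""
    return [[col <= row for col in range(n)] for row in range(n)]


def select_causal_mask(S, precomputed=None):
    """Pick the causal mask for a sequence of length S.

    `precomputed` plays the role of a mask built once with size
    context_length x context_length.
    """
    # Same condition as in the quoted snippet: build a fresh S x S mask
    # when there is no precomputed mask or the sequence is shorter than it.
    if precomputed is None or S < len(precomputed):
        return tril_mask(S)
    # ASSUMPTION: otherwise the precomputed mask is reused as is. That only
    # works when S equals its size, so a longer sequence is rejected here.
    # The real library may fail in a different way.
    if S > len(precomputed):
        raise ValueError(
            f"sequence length {S} exceeds context_length {len(precomputed)}"
        )
    return precomputed


# Usage: a mask precomputed for context_length = 4
context_mask = tril_mask(4)
short = select_causal_mask(2, context_mask)  # fresh 2 x 2 mask
full = select_causal_mask(4, context_mask)   # precomputed mask reused
```

Under these assumptions, any sequence up to `context_length` gets a valid mask, which is why choosing the maximum sequence length is the safe setting.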