syncdoth / RetNet

Hugging Face-compatible implementation of RetNet (Retentive Networks, https://arxiv.org/pdf/2307.08621.pdf), including parallel, recurrent, and chunkwise forward passes.
MIT License

ValueError: not enough values to unpack (expected 2, got 1) #11

Closed: pathoncyp closed this issue 1 year ago

pathoncyp commented 1 year ago

Hey,

Thank you for this great work! An error occurred when I used the model to generate text:

pathoncyp commented 1 year ago

  File "E:\RetNet-main-huggingface\retnet\modeling_retnet.py", line 368, in forward
    batch_size, seq_length = input_ids.shape
  ValueError: not enough values to unpack (expected 2, got 1)
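The error comes from tuple unpacking: `input_ids.shape` is a `torch.Size`, which is a subclass of `tuple`, so a 1-D tensor has a one-element shape and cannot be unpacked into two names. The minimal sketch below uses plain tuples in place of tensor shapes so it runs without PyTorch; `forward_shape_check` is a hypothetical stand-in for line 368, not the repository's actual code.

```python
def forward_shape_check(shape):
    """Hypothetical stand-in for the failing line in the traceback:
    `batch_size, seq_length = input_ids.shape` only works for a 2-D shape."""
    batch_size, seq_length = shape
    return batch_size, seq_length


# A 2-D (batch, seq) shape unpacks fine.
print(forward_shape_check((1, 7)))  # (1, 7)

# A 1-D shape, e.g. a sampled token with no batch dimension, does not.
try:
    forward_shape_check((1,))
except ValueError as err:
    print(err)  # not enough values to unpack (expected 2, got 1)
```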

Dune-Z commented 1 year ago

@pathoncyp Hi, a quick fix: add a batch dimension to the generated token.

  generated.append(token)
  if early_stopping and (token == eos_token_id).all():
      break
  token = token.unsqueeze(0)  # add this line: restore the batch dimension before the next forward pass
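Which axis to unsqueeze depends on the shape of the sampled `token`, which the thread does not state. The sketch below assumes (unconfirmed) that `token` is 1-D with shape `(batch_size,)`, and mimics the shape rule of `torch.Tensor.unsqueeze` with plain tuples; `unsqueeze_shape` is a hypothetical helper written for illustration. Under that assumption, `unsqueeze(0)` is fine for a single prompt, while `unsqueeze(-1)` keeps the batch axis first when several prompts are generated at once.

```python
def unsqueeze_shape(shape, dim):
    """Shape that torch.Tensor.unsqueeze(dim) would produce: a size-1 axis is
    inserted at `dim`; a negative `dim` counts from the end of the output."""
    ndim = len(shape)
    # unsqueeze accepts dim in [-ndim - 1, ndim]
    if not -ndim - 1 <= dim <= ndim:
        raise IndexError(f"dim {dim} out of range for a {ndim}-D shape")
    if dim < 0:
        dim += ndim + 1
    return shape[:dim] + (1,) + shape[dim:]


# Batch size 1: both choices give (batch_size, seq_length) == (1, 1).
print(unsqueeze_shape((1,), 0))   # (1, 1)
print(unsqueeze_shape((1,), -1))  # (1, 1)

# Batch size 4: unsqueeze(0) gives (1, 4), which would be read as
# batch_size=1, seq_length=4; unsqueeze(-1) gives the intended (4, 1).
print(unsqueeze_shape((4,), 0))   # (1, 4)
print(unsqueeze_shape((4,), -1))  # (4, 1)
```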
pathoncyp commented 1 year ago

Thank you very much