togethercomputer / OpenChatKit

Apache License 2.0

RuntimeError: The size of tensor a (2048) must match the size of tensor b (2131) at non-singleton dimension 3 #92

Open lclfans opened 1 year ago

lclfans commented 1 year ago

Describe the bug: When I run $ python inference/bot.py --model togethercomputer/Pythia-Chat-Base-7B --retrieval, it reports a RuntimeError: The size of tensor a (2048) must match the size of tensor b (2131) at non-singleton dimension 3

To Reproduce: Steps to reproduce the behavior:

  0. Use only the CPU for inference.
  1. Run the command: $ python inference/bot.py --model togethercomputer/Pythia-Chat-Base-7B --retrieval
  2. Enter the question below:

    write a hello world in C language

Then the error occurs.

Expected behavior: There should be no error.

Additional context: Detailed error output:

    write a hello world in C language
    Traceback (most recent call last):
      File "/home/robili/ai/OpenChatKit/inference/bot.py", line 287, in <module>
        main()
      File "/home/robili/ai/OpenChatKit/inference/bot.py", line 283, in main
        ).cmdloop()
      File "/home/robili/miniconda3/envs/OpenChatKit/lib/python3.10/cmd.py", line 138, in cmdloop
        stop = self.onecmd(line)
      File "/home/robili/miniconda3/envs/OpenChatKit/lib/python3.10/cmd.py", line 217, in onecmd
        return func(arg)
      File "/home/robili/ai/OpenChatKit/inference/bot.py", line 151, in do_say
        output = self._model.do_inference(
      File "/home/robili/ai/OpenChatKit/inference/bot.py", line 93, in do_inference
        outputs = self._model.generate(
      File "/home/robili/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
        return func(*args, **kwargs)
      File "/home/robili/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/transformers/generation_utils.py", line 1326, in generate
        return self.sample(
      File "/home/robili/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/transformers/generation_utils.py", line 1944, in sample
        outputs = self(
      File "/home/robili/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/robili/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
        output = old_forward(*args, **kwargs)
      File "/home/robili/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 619, in forward
        outputs = self.gpt_neox(
      File "/home/robili/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/robili/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
        output = old_forward(*args, **kwargs)
      File "/home/robili/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 511, in forward
        outputs = layer(
      File "/home/robili/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/robili/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
        output = old_forward(*args, **kwargs)
      File "/home/robili/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 319, in forward
        attention_layer_outputs = self.attention(
      File "/home/robili/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/robili/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
        output = old_forward(*args, **kwargs)
      File "/home/robili/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 153, in forward
        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
      File "/home/robili/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 220, in _attn
        attn_scores = torch.where(causal_mask, attn_scores, mask_value)
    RuntimeError: The size of tensor a (2048) must match the size of tensor b (2131) at non-singleton dimension 3

orangetin commented 1 year ago

> 0 only use CPU for inference

Are you using just the CPU for inference or CUDA 0 and CPU?

deepgithubforever commented 1 year ago

I have a Samsung M31. It seems my phone was under attack, so I am learning and developing security.

deepgithubforever commented 1 year ago

Android

lclfans commented 1 year ago

> > 0 only use CPU for inference
>
> Are you using just the CPU for inference or CUDA 0 and CPU?

I changed the code to use only the CPU for inference.

lclfans commented 1 year ago

The code change looks like this: [image]

deepgithubforever commented 1 year ago

Awesome

orangetin commented 1 year ago

> > > 0 only use CPU for inference
> >
> > Are you using just the CPU for inference or CUDA 0 and CPU?
>
> I changed code and just use CPU for inference

Okay, I'll try to reproduce this. Have you tried it without the --retrieval flag to see if that works? Looking at your log, it looks like it errors out before it gets to the retrieval part.

orangetin commented 1 year ago

@lclfans I'm not able to reproduce this specific error even with the --retrieval flag. OCK doesn't officially support CPU-only inference just yet.

Could you try replacing the contents of bot.py with the contents of this file and then running python inference/bot.py --model togethercomputer/Pythia-Chat-Base-7B --retrieval -r MAX_RAM (replace MAX_RAM with the maximum amount of RAM you'd like to allocate)?

The change you made to wikipedia.py looks good.

Jblauvs commented 1 year ago

I found that this was due to the tokenizer not having truncation and max_length set correctly. Once I set them to appropriate values, I never saw this error again. You'll want to make sure that the max_length set here plus your maximum output length is <= the maximum positional embeddings of the model.

nd7141 commented 1 year ago

Hi @Jblauvs ,

Do you mind sharing here particular lines of code you changed?

lclfans commented 1 year ago

> > > > 0 only use CPU for inference
> > >
> > > Are you using just the CPU for inference or CUDA 0 and CPU?
> >
> > I changed code and just use CPU for inference
>
> Okay, I'll try to reproduce this. Have you tried it without the --retrieval flag to see if that works? Looking at your log, it looks like it errors out before it gets to the retrieval part.

There is no error without the --retrieval flag.

lclfans commented 1 year ago

> I found that this was due to the tokenizer not having truncation and max_length set correctly. Once I set it for an appropriate amount I never saw this error again. You'll want to make sure the amount set here + your max output length <= the maximum positional embeddings of the model.

Hi @Jblauvs, could you show your code change?

Jblauvs commented 1 year ago

@lclfans @nd7141 This is hacked in for now for my use case, but I could come up with a PR given a bit of time. I have max_tokens set to 256. The total of max_tokens + max_length should be 2048 or less.
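A minimal sketch of that arithmetic, in plain Python. The helper name prompt_budget and its margin parameter are hypothetical and not part of OpenChatKit; the 2048 and 256 values are the ones quoted in this thread.

```python
def prompt_budget(max_positions: int, max_new_tokens: int, margin: int = 0) -> int:
    """Return how many prompt tokens fit once room for the reply is reserved.

    Hypothetical helper for illustration; it is not part of OpenChatKit.
    """
    budget = max_positions - max_new_tokens - margin
    if budget <= 0:
        raise ValueError("max_new_tokens leaves no room for the prompt")
    return budget


# Values from this thread: 2048 positions, 256 generated tokens.
limit = prompt_budget(max_positions=2048, max_new_tokens=256)
print(limit)  # 1792

# A tokenizer max_length of 1790 stays inside that limit.
assert 1790 + 256 <= 2048
```

Any tokenizer max_length at or below the returned value keeps prompt plus reply within the model's position limit.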

https://github.com/togethercomputer/OpenChatKit/blob/71dd823e963c8436d7e230ebf09ad8de93644163/inference/bot.py#L89 This should be changed to:

        self._tokenizer(prompt, return_tensors='pt', max_length=1790, truncation=True)

Depending on your usage you may also want to change this: https://github.com/togethercomputer/OpenChatKit/blob/71dd823e963c8436d7e230ebf09ad8de93644163/inference/bot.py#L84 to:

        self._tokenizer = AutoTokenizer.from_pretrained(model_name, truncation_side='left')

The system prepends previous conversation to the prompt, so rather than snip on the right side you may want to snip the left side.
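To illustrate the difference between the two sides, here is a rough sketch on plain lists of token ids. The truncate function and the sample ids are hypothetical; this is not the tokenizer's own implementation, only the idea behind truncation_side.

```python
from typing import List


def truncate(token_ids: List[int], max_length: int, side: str = "right") -> List[int]:
    """Keep at most max_length tokens, dropping from the given side's opposite end.

    side="right" keeps the start of the sequence (drops the tail);
    side="left" keeps the end of the sequence (drops the head).
    """
    if max_length <= 0:
        raise ValueError("max_length must be positive")
    if len(token_ids) <= max_length:
        return list(token_ids)
    if side == "left":
        return token_ids[-max_length:]
    if side == "right":
        return token_ids[:max_length]
    raise ValueError("side must be 'left' or 'right'")


# Hypothetical ids: old conversation first, the new question last.
history = [101, 102, 103]
question = [201, 202]
prompt = history + question

print(truncate(prompt, 3, side="right"))  # [101, 102, 103]
print(truncate(prompt, 3, side="left"))   # [103, 201, 202]
```

With right-side truncation the newest question is the part that gets cut; with left-side truncation the oldest history is dropped first.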

I can go into detail about the why if that's of interest.

Jblauvs commented 1 year ago

I'd offer up that the reason it occurs immediately with the --retrieval flag is that the context is then added to the prompt, which probably adds up to more than the maximum 2048 tokens and so it blows up. The same would happen if you kept talking to the bot, since the conversation is prepended to the prompt as well.
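The two numbers in the error message are consistent with that explanation. The sketch below is an assumption about the mechanism, not the library's code: it supposes the causal mask is sliced out of a fixed buffer sized to the model's 2048 positions, as the torch.where(causal_mask, attn_scores, mask_value) line in the traceback suggests. A plain Python list stands in for the buffer, and mask_and_score_widths is a hypothetical name.

```python
MAX_POSITIONS = 2048  # assumed size of the precomputed causal-mask buffer


def mask_and_score_widths(sequence_length: int, max_positions: int = MAX_POSITIONS):
    """Return (mask_width, scores_width) for one pass over the whole prompt.

    Illustrative only: a list stands in for one row of the mask buffer.
    """
    mask_buffer_row = [True] * max_positions
    # Slicing past the end of a sequence clamps silently instead of raising,
    # so the mask can never be wider than the buffer.
    mask_width = len(mask_buffer_row[:sequence_length])
    # The attention scores are as wide as the actual token sequence.
    scores_width = sequence_length
    return mask_width, scores_width


print(mask_and_score_widths(2131))  # (2048, 2131)
print(mask_and_score_widths(1790))  # (1790, 1790)
```

Under this assumption, a 2131-token prompt yields a 2048-wide mask and 2131-wide scores, which matches the sizes reported for tensor a and tensor b.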