emma-mens / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Map out the code for key-value cache #1

Open emma-mens opened 1 year ago

emma-mens commented 1 year ago

(llm) [emazuh@g3032 lm-evaluation-harness]$ python main.py --model hf-causal --model_args pretrained=facebook/opt-350m --tasks pile_stackexchange --device cuda:0
decoder past_key_values size 16 n_layers 24
attn 2 optdecoder 2 56 56   (repeated 24 times)
next_decoder_cache 232
causal past_key_values size 232

(llm) [emazuh@g3032 lm-evaluation-harness]$ python main.py --model hf-causal --model_args pretrained=facebook/opt-1.3b --tasks pile_stackexchange --device cuda:0
decoder past_key_values size 16 n_layers 24
attn 2 optdecoder 2 56 56   (repeated 24 times)
next_decoder_cache 232
causal past_key_values size 232

(llm) [emazuh@g3032 lm-evaluation-harness]$ python main.py --model hf-causal --model_args pretrained=facebook/opt-2.7b --tasks pile_stackexchange --device cuda:0
decoder past_key_values size 16 n_layers 32
attn 2 optdecoder 2 56 56   (repeated 32 times)
next_decoder_cache 296
causal past_key_values size 296
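The sizes printed above look like shallow Python object sizes rather than tensor memory. On 64-bit CPython, `sys.getsizeof(None)` is 16 bytes and a tuple costs a 40-byte header plus one 8-byte pointer per element, which gives 56 for a 2-tuple, 232 for a 24-tuple and 296 for a 32-tuple. Those match the per-layer `56` values and the `next_decoder_cache` / `past_key_values` sizes, assuming they were produced with `sys.getsizeof` (the debug code itself is not shown in this issue). The sketch below checks that arithmetic and shows that `getsizeof` ignores whatever a tuple points to; the 1 MiB buffers are a hypothetical stand-in for key/value tensors.

```python
import struct
import sys

# Pointer width in bytes (8 on a 64-bit CPython build).
PTR = struct.calcsize("P")


def tuple_overhead(n: int) -> int:
    """Shallow size of an n-element tuple as reported by sys.getsizeof."""
    return sys.getsizeof(tuple(range(n)))


# Hypothetical stand-in for one layer's (key, value) pair: two 1 MiB buffers.
layer_kv = (bytearray(1 << 20), bytearray(1 << 20))

# getsizeof is shallow: it counts the tuple's own header and pointers,
# not the buffers those pointers refer to.
shallow = sys.getsizeof(layer_kv)
deep = shallow + sum(sys.getsizeof(buf) for buf in layer_kv)

print("None:", sys.getsizeof(None))
print("2-tuple:", tuple_overhead(2))
print("24-tuple:", tuple_overhead(24))
print("32-tuple:", tuple_overhead(32))
print("layer_kv shallow:", shallow, "deep:", deep)
```

If that reading is right, the real cache footprint in PyTorch would instead come from summing `t.numel() * t.element_size()` over every cached key and value tensor.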

emma-mens commented 1 year ago

Stack traces to the decoder (`OPTDecoder.forward`) and to the causal-LM forward (`OPTForCausalLM.forward`)

past_key_values size 16
  0%|          | 0/30378 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/main.py", line 108, in <module>
    main()
  File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/main.py", line 79, in main
    results = evaluator.simple_evaluate(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/utils.py", line 160, in _wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/evaluator.py", line 86, in simple_evaluate
    results = evaluate(
              ^^^^^^^^^
  File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/utils.py", line 160, in _wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/evaluator.py", line 247, in evaluate
    resps = getattr(lm, reqtype)([req.args for req in reqs])
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/base.py", line 821, in fn
    rem_res = getattr(self.lm, attr)(remaining_reqs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/base.py", line 210, in loglikelihood_rolling
    string_nll = self._loglikelihood_tokens(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/base.py", line 296, in _loglikelihood_tokens
    self._model_call(batched_inps), dim=-1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/models/huggingface.py", line 386, in _model_call
    return self.model(inputs)["logits"]
           ^^^^^^^^^^^^^^^^^^
  File "/gscratch/stf/emazuh/miniconda3/envs/llm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mmfs1/gscratch/scrubbed/emazuh/transformers/src/transformers/models/opt/modeling_opt.py", line 944, in forward
    outputs = self.model.decoder(
              ^^^^^^^^^^^^^^^^^^^
  File "/gscratch/stf/emazuh/miniconda3/envs/llm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mmfs1/gscratch/scrubbed/emazuh/transformers/src/transformers/models/opt/modeling_opt.py", line 618, in forward
# https://github.com/emma-mens/transformers/blob/main/src/transformers/models/opt/modeling_opt.py#L614
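The first trace above ends inside `OPTDecoder.forward`, which loops over the decoder layers and collects each layer's cached `(key, value)` pair into `next_decoder_cache`. The sketch below mimics that bookkeeping at the level of tensor shapes only. It assumes the legacy tuple-of-tuples cache layout, where each key/value tensor has shape `(batch, num_heads, seq_len, head_dim)` and new states are concatenated onto the past along the sequence axis; newer transformers releases wrap the cache in `Cache` classes such as `DynamicCache` instead. All function names are hypothetical, and the example settings (24 layers, 16 heads of 64 dims for `opt-350m`) are assumed from the published config and worth checking against `config.json`.

```python
from typing import Optional, Tuple

Shape = Tuple[int, int, int, int]  # (batch, num_heads, seq_len, head_dim)
LayerCache = Tuple[Shape, Shape]   # (key, value) shapes for one decoder layer
Cache = Tuple[LayerCache, ...]     # one entry per decoder layer


def attention_step(past: Optional[LayerCache], batch: int, num_heads: int,
                   new_tokens: int, head_dim: int) -> LayerCache:
    """Return the (key, value) shapes one layer would cache after this step."""
    new = (batch, num_heads, new_tokens, head_dim)
    if past is None:
        return (new, new)
    # Past and new states are concatenated along the sequence axis (dim 2).
    k, v = past
    return ((k[0], k[1], k[2] + new_tokens, k[3]),
            (v[0], v[1], v[2] + new_tokens, v[3]))


def decoder_step(past_key_values: Optional[Cache], n_layers: int, batch: int,
                 num_heads: int, new_tokens: int, head_dim: int) -> Cache:
    """Build next_decoder_cache by collecting one (key, value) pair per layer."""
    next_decoder_cache: Cache = ()
    for idx in range(n_layers):
        past = past_key_values[idx] if past_key_values is not None else None
        next_decoder_cache += (
            attention_step(past, batch, num_heads, new_tokens, head_dim),
        )
    return next_decoder_cache


def cache_bytes(cache: Cache, bytes_per_elem: int) -> int:
    """Memory held by the cached tensors themselves (not the Python tuples)."""
    total = 0
    for key_shape, value_shape in cache:
        for b, h, s, d in (key_shape, value_shape):
            total += b * h * s * d * bytes_per_elem
    return total


# opt-350m-like settings (assumed): 10-token prompt, then one generated token, fp16.
cache = decoder_step(None, n_layers=24, batch=1, num_heads=16, new_tokens=10, head_dim=64)
cache = decoder_step(cache, n_layers=24, batch=1, num_heads=16, new_tokens=1, head_dim=64)
print(len(cache), cache[0][0], cache_bytes(cache, bytes_per_elem=2))
```

The total reduces to `2 * n_layers * batch * seq_len * hidden_size * bytes_per_elem`, so the cache grows linearly with both depth and sequence length.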
Traceback (most recent call last):
  File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/main.py", line 108, in <module>
    main()
  File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/main.py", line 79, in main
    results = evaluator.simple_evaluate(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/utils.py", line 160, in _wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/evaluator.py", line 86, in simple_evaluate
    results = evaluate(
              ^^^^^^^^^
  File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/utils.py", line 160, in _wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/evaluator.py", line 247, in evaluate
    resps = getattr(lm, reqtype)([req.args for req in reqs])
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/base.py", line 821, in fn
    rem_res = getattr(self.lm, attr)(remaining_reqs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/base.py", line 210, in loglikelihood_rolling
    string_nll = self._loglikelihood_tokens(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/base.py", line 296, in _loglikelihood_tokens
    self._model_call(batched_inps), dim=-1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/models/huggingface.py", line 386, in _model_call
    return self.model(inputs)["logits"]
           ^^^^^^^^^^^^^^^^^^
  File "/gscratch/stf/emazuh/miniconda3/envs/llm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mmfs1/gscratch/scrubbed/emazuh/transformers/src/transformers/models/opt/modeling_opt.py", line 942, in forward
# https://github.com/emma-mens/transformers/blob/main/src/transformers/models/opt/modeling_opt.py#L936
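Both traces above were captured mid-forward. One lightweight way to record such a call chain without stopping the run is `traceback.print_stack`, which prints the calling frames and lets execution continue. The classes and functions below are hypothetical toy stand-ins named after frames in the traces, not the real lm-evaluation-harness or transformers code.

```python
import io
import traceback


class ToyDecoder:
    """Hypothetical stand-in for OPTDecoder."""

    def forward(self, inputs, trace_file=None):
        # Print how execution reached this point (most recent call last).
        traceback.print_stack(file=trace_file)
        return [x + 1 for x in inputs]


class ToyCausalLM:
    """Hypothetical stand-in for OPTForCausalLM."""

    def __init__(self):
        self.decoder = ToyDecoder()

    def forward(self, inputs, trace_file=None):
        return {"logits": self.decoder.forward(inputs, trace_file=trace_file)}


def model_call(model, inputs, trace_file=None):
    """Hypothetical stand-in for the harness's _model_call."""
    return model.forward(inputs, trace_file=trace_file)["logits"]


# Capture the stack into a buffer instead of stderr so it can be inspected.
buf = io.StringIO()
logits = model_call(ToyCausalLM(), [1, 2, 3], trace_file=buf)
print(logits)
print(buf.getvalue())
```

Passing `file=None` (the default) writes to stderr instead, which is closer to how the traces in this issue appear in the console.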