Open emma-mens opened 1 year ago
Stack trace to the decoder:

```
past_key_values size 16
0%| | 0/30378 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/main.py", line 108, in <module>
main()
File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/main.py", line 79, in main
results = evaluator.simple_evaluate(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/utils.py", line 160, in _wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/evaluator.py", line 86, in simple_evaluate
results = evaluate(
^^^^^^^^^
File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/utils.py", line 160, in _wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/evaluator.py", line 247, in evaluate
resps = getattr(lm, reqtype)([req.args for req in reqs])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/base.py", line 821, in fn
rem_res = getattr(self.lm, attr)(remaining_reqs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/base.py", line 210, in loglikelihood_rolling
string_nll = self._loglikelihood_tokens(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/base.py", line 296, in _loglikelihood_tokens
self._model_call(batched_inps), dim=-1
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/models/huggingface.py", line 386, in _model_call
return self.model(inputs)["logits"]
^^^^^^^^^^^^^^^^^^
File "/gscratch/stf/emazuh/miniconda3/envs/llm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mmfs1/gscratch/scrubbed/emazuh/transformers/src/transformers/models/opt/modeling_opt.py", line 944, in forward
outputs = self.model.decoder(
^^^^^^^^^^^^^^^^^^^
File "/gscratch/stf/emazuh/miniconda3/envs/llm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mmfs1/gscratch/scrubbed/emazuh/transformers/src/transformers/models/opt/modeling_opt.py", line 618, in forward
```

Log statement location: https://github.com/emma-mens/transformers/blob/main/src/transformers/models/opt/modeling_opt.py#L614

```
Traceback (most recent call last):
File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/main.py", line 108, in <module>
main()
File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/main.py", line 79, in main
results = evaluator.simple_evaluate(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/utils.py", line 160, in _wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/evaluator.py", line 86, in simple_evaluate
results = evaluate(
^^^^^^^^^
File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/utils.py", line 160, in _wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/evaluator.py", line 247, in evaluate
resps = getattr(lm, reqtype)([req.args for req in reqs])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/base.py", line 821, in fn
rem_res = getattr(self.lm, attr)(remaining_reqs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/base.py", line 210, in loglikelihood_rolling
string_nll = self._loglikelihood_tokens(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/base.py", line 296, in _loglikelihood_tokens
self._model_call(batched_inps), dim=-1
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mmfs1/gscratch/scrubbed/emazuh/lm-evaluation-harness/lm_eval/models/huggingface.py", line 386, in _model_call
return self.model(inputs)["logits"]
^^^^^^^^^^^^^^^^^^
File "/gscratch/stf/emazuh/miniconda3/envs/llm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mmfs1/gscratch/scrubbed/emazuh/transformers/src/transformers/models/opt/modeling_opt.py", line 942, in forward
```

Log statement location: https://github.com/emma-mens/transformers/blob/main/src/transformers/models/opt/modeling_opt.py#L936
- [x] Usage in the decoder layer and the corresponding `past_key_values` usage
- [x] At the moment, the PyTorch code seems to use a "giant?" tuple as the cache.

Sample logs from the above for loop are shown below; the log statement locations can be found in this commit. Given that the cache is currently just a plain Python tuple, maybe we don't need a separate simulation for testing compression algorithms? Alternatively, we could characterize the kinds of cache vectors that typically occur for different datasets and use those characterizations to define the simulation vectors we attempt to compress. Also, for an initial KV compression, we could start with the standard quantization approaches you've worked on before? (A sketch of that follows the logs below.)
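To make the tuple structure concrete before the raw logs, here is a minimal inspection sketch (my own, not from the logging commit; the prompt string and variable names are illustrative) that prints the nested layout the log statements above are walking: one entry per decoder layer, each a `(key, value)` pair of tensors.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Sketch only: assumes a transformers version where past_key_values is
# still the legacy nested Python tuple (as in the fork linked above),
# not a Cache object as in newer releases.
tok = AutoTokenizer.from_pretrained("facebook/opt-350m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
inputs = tok("inspecting the KV cache", return_tensors="pt")

with torch.no_grad():
    out = model(**inputs, use_cache=True)

pkv = out.past_key_values
print(len(pkv))         # one entry per decoder layer (24 for opt-350m)
print(len(pkv[0]))      # 2: the (key, value) pair for that layer
print(pkv[0][0].shape)  # (batch, num_heads, seq_len, head_dim)
```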
```
(llm) [emazuh@g3032 lm-evaluation-harness]$ python main.py --model hf-causal --model_args pretrained=facebook/opt-350m --tasks pile_stackexchange --device cuda:0
decoder past_key_values size 16 n_layers 24
attn 2 optdecoder 2 56 56        [repeated once per layer, 24x]
next_decoder_cache 232 causal past_key_values size 232

(llm) [emazuh@g3032 lm-evaluation-harness]$ python main.py --model hf-causal --model_args pretrained=facebook/opt-1.3b --tasks pile_stackexchange --device cuda:0
decoder past_key_values size 16 n_layers 24
attn 2 optdecoder 2 56 56        [repeated once per layer, 24x]
next_decoder_cache 232 causal past_key_values size 232

(llm) [emazuh@g3032 lm-evaluation-harness]$ python main.py --model hf-causal --model_args pretrained=facebook/opt-2.7b --tasks pile_stackexchange --device cuda:0
decoder past_key_values size 16 n_layers 32
attn 2 optdecoder 2 56 56        [repeated once per layer, 32x]
next_decoder_cache 296 causal past_key_values size 296
```
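As a first cut at the "standard quantization approaches" idea, here is a minimal sketch of a symmetric per-tensor int8 baseline that walks the tuple cache directly (function names are hypothetical, not from the codebase):

```python
import torch

def quantize_int8(t: torch.Tensor):
    # Symmetric per-tensor int8: store int8 values plus one float scale.
    scale = t.abs().amax().clamp(min=1e-8) / 127.0
    q = torch.round(t / scale).clamp(-127, 127).to(torch.int8)
    return q, scale

def dequantize_int8(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.to(torch.float32) * scale

def compress_cache(past_key_values):
    # The legacy cache is just nested tuples of tensors, so we can walk
    # it directly -- no separate simulation needed for a first experiment.
    return tuple(
        tuple(quantize_int8(t) for t in layer) for layer in past_key_values
    )

def decompress_cache(compressed):
    return tuple(
        tuple(dequantize_int8(q, s) for (q, s) in layer) for layer in compressed
    )
```

Round-tripping the cache through compress/decompress and comparing perplexity against the uncompressed run would give a first read on how much precision the cache actually needs; collecting per-tensor statistics (max, std) the same way could feed the dataset-characterization idea above.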