EleutherAI / lm-evaluation-harness

A framework for few-shot evaluation of language models.
https://www.eleuther.ai
MIT License

Issue with `state-spaces/transformerpp-2.7b` when generating #2182

Closed jhuang265 closed 2 weeks ago

jhuang265 commented 1 month ago

There appears to be an issue with the `state-spaces/transformerpp-2.7b` model (in the `mamba_ssm` family of models) that causes a failure when generating (`Running generate_until requests`). The failure doesn't occur for `Running loglikelihood requests`, so I suspect it's specific to the underlying generation calls. It also doesn't occur for any other model with the Mamba architecture.

The full stack trace is:

Traceback (most recent call last):                                                                                                                                                                                    
  File "/home/venvs/project_path/bin/lm_eval", line 8, in <module>                                                                                                           
    sys.exit(cli_evaluate())                                                                                                                                                                                          
  File "/home/project_path/lm-evaluation-harness/lm_eval/__main__.py", line 382, in cli_evaluate                                                                             
    results = evaluator.simple_evaluate(                                                                                                                                                                              
  File "/home/project_path/lm-evaluation-harness/lm_eval/utils.py", line 397, in _wrapper                                                                                    
    return fn(*args, **kwargs)                                                                                                                                                                                        
  File "/home/project_path/lm-evaluation-harness/lm_eval/evaluator.py", line 296, in simple_evaluate                                                                         
    results = evaluate(                                                                                                                                                                                               
  File "/home/project_path/lm-evaluation-harness/lm_eval/utils.py", line 397, in _wrapper                                                                                    
    return fn(*args, **kwargs)                                                                                                                                                                                        
  File "/home/project_path/lm-evaluation-harness/lm_eval/evaluator.py", line 468, in evaluate                                                                                
    resps = getattr(lm, reqtype)(cloned_reqs)                                                                                                                                                                         
  File "/home/project_path/lm-evaluation-harness/lm_eval/models/huggingface.py", line 1249, in generate_until                                                                
    cont = self._model_generate(                                                                                                                                                                                      
  File "/home/project_path/lm-evaluation-harness/lm_eval/models/mamba_lm.py", line 119, in _model_generate                                                                   
    return self.model.generate(                                                                                                                                                                                       
  File "/home/project_path/mamba/mamba_ssm/utils/generation.py", line 260, in generate                                                                                       
    output = decode(                                                                                                                                                                                                  
  File "/home/venvs/project_path/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context                                                     
    return func(*args, **kwargs)                                                                                                                                                                                      
  File "/home/project_path/mamba/mamba_ssm/utils/generation.py", line 221, in decode                                                                                         
    scores.append(get_logits(sequences[-1], inference_params))                                                                                                                                                        
  File "/home/project_path/mamba/mamba_ssm/utils/generation.py", line 184, in get_logits                                                                                     
    logits = model(                                                                                                                                                                                                   
  File "/home/venvs/project_path/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl                                                  
    return self._call_impl(*args, **kwargs)                                                                                                                                                                           
  File "/home/venvs/project_path/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl                                                          
    return forward_call(*args, **kwargs)                                                                                                                                                                              
  File "/home/project_path/mamba/mamba_ssm/models/mixer_seq_simple.py", line 279, in forward                                                                                 
    hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)                                                                                                                       
  File "/home/venvs/project_path/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl                                                  
    return self._call_impl(*args, **kwargs)                                                                                                                                                                           
  File "/home/venvs/project_path/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl                                                          
    return forward_call(*args, **kwargs)                                                                                                                                                                              
  File "/home/project_path/mamba/mamba_ssm/models/mixer_seq_simple.py", line 194, in forward                                                                                 
    hidden_states, residual = layer(                                                                                                                                                                                  
  File "/home/venvs/project_path/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl                                                  
    return self._call_impl(*args, **kwargs)                                                                                                                                                                           
  File "/home/venvs/project_path/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl                                                          
    return forward_call(*args, **kwargs)                                                                                                                                                                              
  File "/home/project_path/mamba/mamba_ssm/modules/block.py", line 67, in forward                                                                                            
    hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)                                                                                                                      
  File "/home/venvs/project_path/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl                                                  
    return self._call_impl(*args, **kwargs)                                                                                                                                                                           
  File "/home/venvs/project_path/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl                                                          
    return forward_call(*args, **kwargs)                                                                                                                                                                              
  File "/home/project_path/mamba/mamba_ssm/modules/mha.py", line 288, in forward                                                                                             
    context = self._update_kvcache_attention(q, kv, inference_params)                                                                                                                                                 
  File "/home/project_path/mamba/mamba_ssm/modules/mha.py", line 193, in _update_kvcache_attention                                                                           
    kv_cache[:, :, 0],                                                                                                                                                                                                
TypeError: tuple indices must be integers or slices, not tuple

Note that I installed both lm-eval and mamba as editable packages in a virtual environment.
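
For what it's worth, the final `TypeError` in the trace is the generic error Python raises when a plain tuple is indexed with a multi-axis expression like `kv_cache[:, :, 0]`, which only tensor/array types support. A minimal, self-contained illustration (the `kv_cache` name here is just a stand-in, not the actual object from `mamba_ssm`):

```python
# Multi-axis indexing such as x[:, :, 0] passes the tuple
# (slice(None), slice(None), 0) to __getitem__. Tensors and ndarrays
# accept that; a plain Python tuple does not.
kv_cache = (0, 1, 2)  # hypothetical stand-in for a tuple where a tensor was expected

try:
    kv_cache[:, :, 0]
except TypeError as e:
    print(e)  # tuple indices must be integers or slices, not tuple
```

This suggests the cache entry reaching `_update_kvcache_attention` is a tuple rather than the tensor the code expects.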

haileyschoelkopf commented 1 month ago

What is your mamba_ssm version? Are you able to load and evaluate this model outside of lm-eval-harness?

And does evaluating Mamba-2 models work for you? If they don't, for now we should pin the mamba_ssm version.

jhuang265 commented 1 month ago

I'm using mamba_ssm version 2.2.2 (I installed it in editable form, but 2.2.2 was the latest release at the time I installed it).
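
Since editable installs can drift from the last tagged release, one way to confirm what the environment actually reports (a sketch using only the standard library; assumes the package registers normal distribution metadata, which editable installs via pip do):

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist: str) -> str:
    """Return the reported version of a distribution, or 'not installed'."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return "not installed"


print(installed_version("mamba_ssm"))  # e.g. "2.2.2" in the setup described above
```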

Evaluating all other Mamba models works fine with lm-evaluation-harness. I've also been able to generate text with state-spaces/transformerpp-2.7b outside the harness (by explicitly calling model.generate(...)) without problems, so I assume there's a minor compatibility issue, specific to this model, in how the harness calls the generation method.

haileyschoelkopf commented 2 weeks ago

Going to close this, as it appears to be something that needs handling on the mamba_ssm side of things, sorry!