mlfoundations / open_lm

A repository for research on medium-sized language models.

NotImplementedError running HF model "mlfoundations/dclm-7b-it" for inference #303

Open neginraoof opened 2 months ago

neginraoof commented 2 months ago

I am trying to use the HF model "mlfoundations/dclm-7b-it" for inference, simply using the code below:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("mlfoundations/dclm-7b-it")
gen_kwargs = {"max_new_tokens": 500, "temperature": 0}
# inputs: tokenized prompt ids (from the model's tokenizer)
output = model.generate(inputs['input_ids'], **gen_kwargs)

I see this warning when loading the model: Some weights of OpenLMForCausalLM were not initialized from the model checkpoint at mlfoundations/dclm-7b-it and are newly initialized: [...]

And I get NotImplementedError:

NotImplementedError: No operator found for `memory_efficient_attention_forward` with inputs:
     query       : shape=(1, 3, 32, 128) (torch.float32)
     key         : shape=(1, 3, 32, 128) (torch.float32)
     value       : shape=(1, 3, 32, 128) (torch.float32)
     attn_bias   : <class 'xformers.ops.fmha.attn_bias.LowerTriangularMask'>
     p           : 0.0

I have also tried model = AutoModel.from_pretrained("mlfoundations/dclm-7b-it"), but this model class also fails with ValueError: Unrecognized configuration class.

Which model class should I use here?

sedrick-keh-tri commented 2 months ago

This is usually an xformers issue. I think the main problem is that xformers doesn't run on CPU, so the quick short-term fix is to make sure you move the model and all input tensors to the GPU. That should resolve the error.
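For reference, a minimal sketch of that short-term fix, assuming a CUDA GPU is available; the prompt string is hypothetical and the model/tokenizer names follow the snippet above:

from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda"  # xformers' memory_efficient_attention kernels require a GPU

tokenizer = AutoTokenizer.from_pretrained("mlfoundations/dclm-7b-it")
model = AutoModelForCausalLM.from_pretrained("mlfoundations/dclm-7b-it").to(device)

# Hypothetical prompt, just for illustration
inputs = tokenizer("What is the capital of France?", return_tensors="pt").to(device)
output = model.generate(inputs["input_ids"], max_new_tokens=500)
print(tokenizer.decode(output[0], skip_special_tokens=True))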

I think the long-term solution here would probably be to get rid of xformers entirely. You can do this locally by setting "attn_name": "torch_attn" and "ffn_type": "swiglu_torch" in the model config. I know the Apple models and the TRI models do this, but I guess the mlfoundations one wasn't updated accordingly. I'm putting in a PR now.
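For anyone who wants to try that override locally before the PR lands, here is a rough sketch, assuming the OpenLM HF config exposes attn_name and ffn_type as top-level attributes (the exact field layout in the config class may differ):

from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("mlfoundations/dclm-7b-it")
# Assumed attribute names, taken from the keys mentioned above
config.attn_name = "torch_attn"    # plain PyTorch attention instead of xformers
config.ffn_type = "swiglu_torch"   # PyTorch SwiGLU FFN instead of the xformers one
model = AutoModelForCausalLM.from_pretrained("mlfoundations/dclm-7b-it", config=config)

If those keys also appear in the checkpoint's config.json, editing them there should have the same effect.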