state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

HuggingFace trainer #108

Open lolofo opened 7 months ago

lolofo commented 7 months ago

Hello,

I'm trying to fine-tune the Mamba model with a Hugging Face Trainer, but I'm facing an issue: AttributeError: 'MambaConfig' object has no attribute 'to_json_string'

This is because MambaConfig does not follow the usual Hugging Face format for configurations: it is a plain dataclass, which is what causes this error.
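For reference, here is a minimal reproduction of the mismatch (a sketch, assuming the MambaConfig import path below is the current one in mamba-ssm):

```python
# MambaConfig is a plain dataclass, so it lacks the to_json_string() method
# that the Hugging Face Trainer calls when it saves a checkpoint.
from dataclasses import is_dataclass

from mamba_ssm.models.config_mamba import MambaConfig

cfg = MambaConfig(d_model=768, n_layer=24, vocab_size=50277)
print(is_dataclass(cfg))               # True  -> dataclass, not a transformers PretrainedConfig
print(hasattr(cfg, "to_json_string"))  # False -> this is what raises the AttributeError above
```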

jyegerlehner commented 7 months ago

Here is an example script that trains Mamba with the Hugging Face transformers library.

Haven't tried it yet or looked closely, but I'd speculate they side-step the issue by overriding the save_model method in their MambaTrainer subclass of Trainer.
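For anyone skimming, a rough sketch of what such an override could look like (my guess, not the linked script's actual code; file names are placeholders):

```python
# Hedged sketch: a Trainer subclass that saves weights and config by hand,
# so nothing ever calls MambaConfig.to_json_string().
import json
import os

import torch
from transformers import Trainer


class MambaTrainer(Trainer):
    def save_model(self, output_dir=None, _internal_call=False):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        # Dump the raw state dict instead of going through save_pretrained().
        torch.save(self.model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
        # MambaConfig is a dataclass, so vars() yields a JSON-serializable dict.
        with open(os.path.join(output_dir, "config.json"), "w") as f:
            json.dump(vars(self.model.config), f, indent=2)
```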

lolofo commented 7 months ago

Thank you, I have implemented a very similar trainer... I'll have a close look at this one, thank you!

jyegerlehner commented 7 months ago

One other thing: it is fairly trivial to extend MambaConfig to add a to_json_string method. This PR includes that change.

lolofo commented 7 months ago

I have added the method to the config and it works, thank you!

RonanKMcGovern commented 7 months ago

Any tips on how to reduce VRAM requirements?

I'm training the 2.8B Mamba and I'm hitting OOM at 16k context on an A100 80GB, with a batch size of 1.

I guess the huggingface Trainer is probably materializing the hidden state h in high-bandwidth GPU memory, and that it has to be stored for each input token of the layer currently being processed? Is that what's making the memory requirements so high?
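The only generic levers I know of on the Trainer side are bf16, gradient accumulation, and an 8-bit optimizer, roughly like the sketch below (placeholder output dir, requires bitsandbytes for the optimizer), but none of those address the per-token state itself:

```python
# Generic memory-saving TrainingArguments, nothing Mamba-specific:
# bf16 roughly halves activation storage vs fp32, gradient accumulation keeps
# the effective batch size up while the per-step batch stays at 1, and the
# 8-bit AdamW shrinks optimizer state for a 2.8B-parameter model considerably.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="mamba-2.8b-ft",       # placeholder path
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    bf16=True,
    optim="adamw_bnb_8bit",           # requires bitsandbytes
    logging_steps=10,
)
```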

lqf0624 commented 5 months ago

> Any tips on how to reduce VRAM requirements?
>
> I'm training the 2.8B Mamba and I'm hitting OOM at 16k context on an A100 80GB, with a batch size of 1.
>
> I guess the huggingface Trainer is probably materializing the hidden state h in high-bandwidth GPU memory, and that it has to be stored for each input token of the layer currently being processed? Is that what's making the memory requirements so high?

I'm not focused on NLP tasks with Mamba, but I encountered the problem too. I don't know exactly why, but after I reinstalled the mamba-ssm library the VRAM usage dropped dramatically: with the same training settings it now needs only about 15GB of VRAM, whereas it would raise OOM before the library was reinstalled. I don't know if this helps you, but it might give you an idea of how to solve it.
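If it helps, my assumption is that the drop comes from the compiled CUDA kernels being picked up after the reinstall; a quick way to check (the extension module names below are what the packages build, to the best of my knowledge):

```python
# Hedged check: mamba-ssm's fast path relies on compiled CUDA extensions.
# If these imports fail, the model falls back to a much more memory-hungry
# reference implementation, which could explain the OOM-vs-15GB difference.
try:
    import selective_scan_cuda   # extension built by the mamba-ssm package
    print("selective_scan_cuda found: fused selective-scan kernels available")
except ImportError:
    print("selective_scan_cuda missing: slow reference scan will be used")

try:
    import causal_conv1d_cuda    # extension built by the optional causal-conv1d package
    print("causal_conv1d_cuda found: fused causal conv kernels available")
except ImportError:
    print("causal_conv1d_cuda missing: falling back to a plain PyTorch conv1d")
```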

RonanKMcGovern commented 5 months ago

Thanks, appreciate that. Yeah, I'll try reinstalling mamba-ssm next time.


lurchyy commented 3 months ago

@lolofo how exactly did you add the method to the config?

lolofo commented 3 months ago

@lurchyy I did something like this:

```python
import json

# Import path of the dataclass config in the mamba-ssm package.
from mamba_ssm.models.config_mamba import MambaConfig


class MambaCustomConfig(MambaConfig):
    """Custom config to make the model run with the HF Trainer."""

    def to_json_string(self):
        return json.dumps(
            {
                "d_model": int(self.d_model),
                "n_layer": int(self.n_layer),
                "vocab_size": int(self.vocab_size),
                "ssm_config": self.ssm_cfg,
                "rms_norm": self.rms_norm,
                "residual_in_fp32": self.residual_in_fp32,
                "fused_add_norm": self.fused_add_norm,
                "pad_vocab_size_multiple": self.pad_vocab_size_multiple,
            }
        )
```
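and then built the model with that config so the Trainer's checkpoint saving can serialize it, roughly like this (a sketch, the sizes are just the 2.8B defaults, not my exact script):

```python
# Rough usage sketch: instantiate the Mamba LM with the custom config so that
# Trainer checkpointing finds config.to_json_string() when it saves.
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

config = MambaCustomConfig(d_model=2560, n_layer=64, vocab_size=50277)
model = MambaLMHeadModel(config)
```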