huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Fine-tuning gemma2-2b on multiple GPUs gets OOM: how do I do only model sharding and no data parallelism (I guess it's going into DDP)? #2019

Closed: bhupendrathore closed this issue 1 month ago

bhupendrathore commented 2 months ago

System Info

Name: transformers Version: 4.45.0.dev0
Name: trl Version: 0.8.6

Reproduction

My code: https://gist.github.com/bhupendrathore/b750a2d9307c6c5b8ee94e54daad97e5

I am trying to use trl for multi-GPU training (the goal is to distribute the model across all 4 GPUs), but I end up getting OOM. Earlier I fine-tuned Mistral 7B with a larger max length and it worked fine (back when Mistral 7B was released). Now it fails even with a 2B model.

The reason I'm using


device_map={"": PartialState().process_index},

is that device_map="auto" fails with:


ValueError: You can't train a model that has been loaded in 8-bit precision on a different device than the one you're training on.

I understand that device_map={"": PartialState().process_index} requires the whole model to fit on one GPU, and that's why it's getting OOM. I have 4 A100s with 40GB of VRAM each.
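A rough back-of-the-envelope sketch of the trade-off described above, assuming weights only (it ignores activations, gradients, optimizer state, adapters and CUDA overhead). The parameter count and bytes per parameter are illustrative inputs, not measured values. With device_map={"": PartialState().process_index}, each process holds a full copy of the weights, so per-GPU weight memory does not shrink as GPUs are added; a layer-sharded map splits it.

```python
def weight_memory_per_gpu_gb(n_params: float, bytes_per_param: float,
                             n_gpus: int, sharded: bool) -> float:
    """Rough per-GPU memory for model weights only, in GB (1 GB = 1e9 bytes).

    sharded=False models one full replica per process (data parallel, as with
    device_map={"": PartialState().process_index}); sharded=True models an
    even split of the layers across n_gpus (naive model sharding).
    """
    if n_gpus < 1:
        raise ValueError("n_gpus must be >= 1")
    total_gb = n_params * bytes_per_param / 1e9
    return total_gb / n_gpus if sharded else total_gb


# Hypothetical inputs: ~2.6e9 parameters loaded in 8-bit (1 byte per parameter).
replicated = weight_memory_per_gpu_gb(2.6e9, 1, n_gpus=4, sharded=False)
sharded = weight_memory_per_gpu_gb(2.6e9, 1, n_gpus=4, sharded=True)
print(f"replicated: {replicated:.2f} GB/GPU, sharded: {sharded:.3f} GB/GPU")
```

Under these assumptions the 8-bit weights of a 2B-class model take only a few GB either way, which hints that the OOM on a 40GB card is driven by something other than the weights themselves.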

Any help or direction would be much appreciated. I also suspect gemma-2-2b might have a memory leak: the other day I was running model.generate one example at a time and it kept failing, but when I called torch.cuda.empty_cache() in each iteration, it worked.

Is there any way, like before, to load the model across different GPUs and do normal fine-tuning without any data parallelism?

device_map="auto" doesn't seem to work either; any help on that would also be appreciated.

Expected behavior

It should fine-tune without OOM. Right now it mostly uses GPU 0, while the other GPUs are underutilized.

bhupendrathore commented 2 months ago

I could use naive pipeline parallelism with a custom device map, like the one below.

I tried many layouts (including zero), and kept changing the map according to nvidia-smi to balance memory usage.

device_map = {'model.embed_tokens': 1,
 'model.layers.0': 1,
 'model.layers.1': 1,
 'model.layers.2': 1,
 'model.layers.3': 1,
 'model.layers.4': 1,
 'model.layers.5': 1,
 'model.layers.6': 1,
 'model.layers.7': 1,
 'model.layers.8': 1,
 'model.layers.9': 1,
 'model.layers.10': 2,
 'model.layers.11': 2,
 'model.layers.12': 2,
 'model.layers.13': 2,
 'model.layers.14': 2,
 'model.layers.15': 2,
 'model.layers.16': 2,
 'model.layers.17': 2,
 'model.layers.18': 2,
 'model.layers.19': 2,
 'model.layers.20': 2,
 'model.layers.21': 2,
 'model.layers.22': 2,
 'model.layers.23': 2,
 'model.layers.24': 2,
 'model.layers.25': 3,
 'model.norm': 3,
 'lm_head': 1}
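Rather than typing out every layer by hand, a map of the same shape can be generated. Below is a minimal sketch; build_device_map is a hypothetical helper (not part of transformers or accelerate), and the module names (model.embed_tokens, model.layers.N, model.norm, lm_head) are taken from the map above. The resulting dict would be passed as device_map= to from_pretrained, presumably in a single process (plain python rather than a multi-process launcher), so each GPU holds a different slice of the model instead of a full replica.

```python
def build_device_map(layers_per_gpu: dict[int, int], embed_gpu: int,
                     norm_gpu: int, head_gpu: int) -> dict[str, int]:
    """Build a contiguous layer-to-GPU map in the same shape as the one above.

    layers_per_gpu maps a GPU index to how many consecutive decoder layers it
    holds; layers are assigned in the dict's insertion order, starting at 0.
    """
    device_map = {"model.embed_tokens": embed_gpu}
    layer = 0
    for gpu, count in layers_per_gpu.items():
        for _ in range(count):
            device_map[f"model.layers.{layer}"] = gpu
            layer += 1
    device_map["model.norm"] = norm_gpu
    device_map["lm_head"] = head_gpu
    return device_map


# Reproduces the hand-written map: layers 0-9 on GPU 1, 10-24 on GPU 2, 25 on GPU 3.
dm = build_device_map({1: 10, 2: 15, 3: 1}, embed_gpu=1, norm_gpu=3, head_gpu=1)
print(len([k for k in dm if k.startswith("model.layers.")]))  # 26
```

Changing the per-GPU counts then rebalances the split in one place. Accelerate's infer_auto_device_map (given a max_memory dict) is another way to get a balanced map without hand-tuning.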

More details: https://github.com/huggingface/blog/blob/main/accelerate-large-models.md

This also avoids the error below:

ValueError: You can't train a model that has been loaded in 8-bit precision on a different device than the one you're training on.

But the OOM problem remains. The code runs fine with a shorter context length, but I still believe this should be doable with this much memory (4 x A100 40GB). Any thoughts?

I also tried putting some blocks on the CPU with llm_int8_enable_fp32_cpu_offload=True, but then I get: ValueError: You can't train a model that has been loaded in 8-bit precision with CPU or "disk offload".

RylanSchaeffer commented 2 months ago

What context length does your dataset have?

bhupendrathore commented 1 month ago

8000 tokens max (gemma2-2b tokenizer); in fact, it fails even with 4096.
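To see why context length matters so much, here is an order-of-magnitude sketch of two tensors that grow with sequence length. Assumptions: n_heads=8, n_layers=26 and vocab=256000 are believed to match the gemma-2-2b config (26 layers agrees with the device map above), but should be checked against model.config; bytes_per_elem=4 assumes fp32 (halve it for bf16); the attention term assumes eager attention that materializes the full seq x seq score matrix and keeps one copy per layer for the backward pass. Everything else (hidden states, MLP activations, gradients) is ignored, so real usage is higher.

```python
def long_context_memory_gb(seq_len: int, batch: int = 1, n_heads: int = 8,
                           n_layers: int = 26, vocab: int = 256_000,
                           bytes_per_elem: int = 4) -> dict[str, float]:
    """Rough size (GB, 1e9 bytes) of two tensors that grow with context length.

    - attention scores: one (batch, heads, seq, seq) matrix per layer, summed
      over all layers as if each were kept for the backward pass.
    - logits: one (batch, seq, vocab) tensor produced by lm_head.
    """
    attn = batch * n_heads * seq_len**2 * bytes_per_elem * n_layers / 1e9
    logits = batch * seq_len * vocab * bytes_per_elem / 1e9
    return {"attention_scores_all_layers": attn, "logits": logits}


for seq in (2048, 4096, 8000):
    est = long_context_memory_gb(seq)
    print(seq, {k: round(v, 1) for k, v in est.items()})
```

The attention term is quadratic in seq_len (doubling the context quadruples it), which is consistent with shorter context lengths training fine. Fused attention kernels avoid materializing the score matrix, so this term does not apply to them.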

bhupendrathore commented 1 month ago

I don't think it's a problem with trl. It's still a problem, but I'm not sure why it's happening; it shouldn't be running out of memory.