Hi, and thanks for using EasyDeL!
```python
# Import path may differ across EasyDeL versions.
from EasyDel import AutoEasyDelModelForCausalLM

model, params = AutoEasyDelModelForCausalLM.from_pretrained(
    huggingface_repo_id_or_path,
    config_kwargs={
        "attn_mechanism": "sharded_vanilla",
        "max_position_embeddings": max_length,
        "bits": 4,
    },
    sharding_axis_dims=(1, 1, 1, -1),
    input_shape=(1, max_length),
    precision=None,
)
```
Use this one.
Thank you. It works. I'm confused about how changing `bits` helps reduce TPU memory during training. I can train a model successfully without adding the `bits` parameter, but if I set `bits` to 4 or 8, it says TPU memory is exhausted. Could you please provide some insight into this?
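For intuition on why lower `bits` should reduce at-rest parameter memory, here is a minimal NumPy sketch (illustrative only, not EasyDeL's actual quantization code): the same weight matrix stored in 8-bit takes a quarter of the bytes of float32.

```python
import numpy as np

# A 4096 x 4096 weight matrix stored in float32:
w_fp32 = np.zeros((4096, 4096), dtype=np.float32)
# The same values quantized to int8 (quantization scales would be
# kept separately and are negligible in size):
w_int8 = np.zeros((4096, 4096), dtype=np.int8)

print(w_fp32.nbytes // 2**20)  # 64 MiB
print(w_int8.nbytes // 2**20)  # 16 MiB, i.e. 4x smaller at rest
```

Note this only accounts for parameter storage; activation and gradient memory during training behave differently, which is what the discussion below is about.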
Yes, it actually works better at larger scale. Right now you're getting an out-of-memory error because you are also using gradient checkpointing. But imagine you are running on many more TPUs with enough memory: then you can choose to benefit from lower-precision operations instead of checkpointing the gradients of operations or modules.
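The gradient-checkpointing half of that trade-off can be sketched with plain JAX (this uses the standard `jax.checkpoint`, not EasyDeL-specific code): a rematerialized block does not store its forward activations for the backward pass, it recomputes them, trading extra compute for lower memory.

```python
import jax
import jax.numpy as jnp

def block(x, w):
    # Stand-in for one model block; its intermediate activations are
    # what checkpointing recomputes instead of storing.
    return jnp.tanh(x @ w)

# Rematerialized version: same math, same gradients, but forward
# activations inside `block` are recomputed during the backward pass.
checkpointed_block = jax.checkpoint(block)

def loss(x, w):
    return checkpointed_block(x, w).sum()

x = jnp.ones((4, 8))
w = jnp.ones((8, 8))
g = jax.grad(loss, argnums=1)(x, w)
print(g.shape)  # (8, 8)
```

The gradients are identical to the un-checkpointed version; only the memory/compute profile changes, which is why with abundant TPU memory you can skip checkpointing entirely and rely on lower precision instead.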
Thank you for the explanation. I reduced the batch size and ran an experiment comparing runs with and without `bits=8`, keeping all other settings constant. I observed that both the speed and the TPU utilization (monitored via wandb with `use_wandb=True`) remained the same. I'm wondering why these metrics did not differ between the two experiments.
Actually, TPU monitoring is just a tool to find out how much memory the model and buffers take up, so it doesn't record forward- and backward-pass memory. You can monitor that too, and it isn't difficult, but there's a high chance your training loop will crash.
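The "model + buffers" number such a monitor reports is essentially the summed at-rest size of the parameter arrays. A minimal sketch of that accounting over a hypothetical params pytree (the names below are made up for illustration):

```python
import jax
import jax.numpy as jnp

# Hypothetical params pytree standing in for real model weights.
params = {
    "embed": jnp.zeros((1000, 64)),
    "layer0": {"w": jnp.zeros((64, 64)), "b": jnp.zeros((64,))},
}

def param_bytes(tree):
    # Sum the at-rest size of every array in the pytree. This is the
    # static "model + buffers" figure; it does NOT include transient
    # forward/backward activation memory.
    leaves = jax.tree_util.tree_leaves(tree)
    return sum(leaf.size * leaf.dtype.itemsize for leaf in leaves)

print(param_bytes(params))  # total parameter bytes
```

Because `bits=8` shrinks only this static figure while the transient activation memory of each step is unchanged, a monitor that samples the static figure can look identical between the two runs.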
Hello, amazing work! I tried LoRA fine-tuning on a Kaggle TPU. However, when I set `bits` to 4, the error below happened. Furthermore, even after adding `precision=jax.lax.Precision("default")`, the error still occurs. Could you tell me how to do 4-bit LoRA training with EasyDeL?