huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Out of memory when using phi3 for token classification #32632

Open xinyudong93 opened 1 month ago

xinyudong93 commented 1 month ago

System Info

I'm using AWS SageMaker to implement a token classification model with Phi-3.

Who can help?

No response

Information

Tasks

Reproduction

I get an OOM error when I try to load phi3-mini for token classification; even with 1 sample per batch, it still occurs.

import torch
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (AutoTokenizer, DataCollatorForTokenClassification,
                          Phi3ForTokenClassification, Trainer, TrainingArguments)

model = Phi3ForTokenClassification.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    num_labels=9,
    id2label=id2label,  # label mappings defined earlier in the script
    label2id=label2id,
    use_cache=False,
    torch_dtype=torch.bfloat16,
)

peft_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.TOKEN_CLS,
)

model = get_peft_model(model, peft_config)

tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    model_max_length=96,  # note: `model_max_length` is the tokenizer kwarg; `max_seq_length` is not
)

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

output_dir = "/opt/ml/model/"

training_args = TrainingArguments(
    output_dir=output_dir,
    bf16=args.bf16,  # bf16 flag from the script's argparse arguments
    learning_rate=2e-5,
    per_device_train_batch_size=1,
    num_train_epochs=1,
    weight_decay=0.01,
    report_to="tensorboard",
    logging_dir=f"{output_dir}/logs",
    logging_strategy="steps",
    logging_steps=50,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

trainer.train()

Expected behavior

Can you help me solve this, or suggest what the problem might be?

amyeroberts commented 1 month ago

Hi, thanks for raising an issue!

This is a question best placed in our forums. We try to reserve GitHub issues for feature requests and bug reports.

To get the best help, I'd suggest sharing as much technical information as possible, including all relevant error information (e.g. the full stack trace and observed memory utilization) and relevant technical details (e.g. the hardware being used).
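For example, a minimal sketch of how to capture peak GPU memory with PyTorch (assuming a CUDA device; the helper function is purely illustrative):

import torch

def report_peak_memory(step_name: str) -> None:
    # Print the peak GPU memory allocated so far, in GiB.
    peak_gib = torch.cuda.max_memory_allocated() / 1024**3
    print(f"{step_name}: peak allocated {peak_gib:.2f} GiB")

torch.cuda.reset_peak_memory_stats()
# ... load the model / run a single training step here ...
report_peak_memory("after one step")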

To help work out whether it's related to your setup or to the model, you can try running with a small model and see if that succeeds.
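For instance, a minimal sketch of that sanity check, keeping the rest of the training script unchanged (the small checkpoint below is just an example substitute):

from transformers import AutoModelForTokenClassification

# Swap in a much smaller checkpoint; if this trains without OOM, the issue is
# likely the size of phi3-mini relative to the available GPU memory, not a bug.
small_model = AutoModelForTokenClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=9,
)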