johnsmith0031 / alpaca_lora_4bit

MIT License

Use gradient checkpoint only for training mode, not evaluation #128

Closed. alex4321 closed this issue 1 year ago.

alex4321 commented 1 year ago

Assume we have the following training script structure (copy-pasted from my script):

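from peft import LoraConfig, get_peft_model
from transformers import TrainingArguments

# load_llama_model_4bit_low_ram, model_to_half, AMPWrapper,
# lora_model_zeros_and_scales_to_half, apply_gradient_checkpointing and
# IterableDatasetTrainer come from alpaca_lora_4bit or my own code.

model, tokenizer = load_llama_model_4bit_low_ram(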
    config_path="../vicuna-13b-GPTQ-4bit-128g/",
    model_path="../vicuna-13b-GPTQ-4bit-128g/vicuna-13b-4bit-128g.safetensors",
    groupsize=128,
    is_v1_model=False,
)
model_to_half(model)

wrapper = AMPWrapper(model)
wrapper.apply_forward()

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.0,
    bias="none",
    task_type="CAUSAL_LM",
)

lora_model = get_peft_model(model, lora_config)
lora_model = lora_model_zeros_and_scales_to_half(lora_model)

apply_gradient_checkpointing(lora_model, checkpoint_ratio=1)

...

training_arguments = TrainingArguments(
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,
    warmup_steps=5,
    optim="adamw_torch",
    num_train_epochs=10,
    learning_rate=3e-4,
    fp16=True,
    logging_steps=20,
    evaluation_strategy="epoch",
    save_strategy="no",
    output_dir="lora-output-directory",
    save_total_limit=3,
    load_best_model_at_end=False,
    ddp_find_unused_parameters=False,
    report_to="none",
    label_names=["labels"],
)
trainer = IterableDatasetTrainer(
    memory_lora_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=training_arguments,
)
lora_model.config.use_cache = False

Don't pay much attention to the classes/models; they are all subclasses of standard transformers classes with minimal changes.

The key issue is here:

apply_gradient_checkpointing(lora_model, checkpoint_ratio=1)
training_arguments = TrainingArguments(
    ...
    evaluation_strategy="epoch",
    ...
    output_dir="lora-output-directory",
    ...
)

So we have gradient checkpointing enabled and, at the same time, evaluation enabled.

So when the Hugging Face Trainer loop runs evaluation, it switches the model to eval mode (model.eval()) and runs the forward pass under torch.no_grad(), so the input tensors do not require gradients.

This causes the following code to emit warnings:

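from torch.utils.checkpoint import checkpoint


class NewForward:

    def __init__(self, layer):
        self.layer = layer
        self.apply_patch()

    def apply_patch(self):
        # Keep the original forward and route every call through new_forward.
        self.layer.old_forward_for_cp = self.layer.forward
        self.layer.forward = self.new_forward

    def new_forward(self, *args, **kwargs):
        # kwargs are captured by the closure; checkpoint() only receives positional args.
        def func(*args):
            return self.layer.old_forward_for_cp(*args, **kwargs)
        # Always checkpointed, even in eval mode under torch.no_grad().
        output = checkpoint(func, *args)
        return output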

The warnings look like this:

warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
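The warning can be reproduced without the Trainer or the 4-bit model. The minimal sketch below assumes PyTorch 1.11 or newer (for the use_reentrant argument) and uses a plain nn.Linear as a stand-in for a LLaMA layer; it calls reentrant checkpoint() in eval mode under torch.no_grad(), which is roughly what happens during evaluation.

```python
import warnings

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

# Stand-in for one transformer layer (an assumption; any nn.Module behaves the same).
layer = nn.Linear(4, 4)

# Mimic the evaluation loop: eval mode plus no_grad.
layer.eval()
# Plain tensor with requires_grad=False, like activations computed under no_grad.
hidden = torch.randn(2, 4)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    with torch.no_grad():
        # The reentrant checkpoint path checks whether any *input* requires grad.
        out = checkpoint(layer, hidden, use_reentrant=True)

messages = [str(w.message) for w in caught]
for message in messages:
    print(message)
```

The layer's own parameters still have requires_grad=True; the reentrant checkpoint only inspects the tensors passed as inputs, so it warns whenever none of them require gradients.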

So I suggest replacing the NewForward code above with the following:

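from torch.utils.checkpoint import checkpoint


class NewForward:

    def __init__(self, layer):
        self.layer = layer
        self.apply_patch()

    def apply_patch(self):
        # Keep the original forward and route every call through new_forward.
        self.layer.old_forward_for_cp = self.layer.forward
        self.layer.forward = self.new_forward

    def new_forward(self, *args, **kwargs):
        # kwargs are captured by the closure; checkpoint() only receives positional args.
        def func(*args):
            return self.layer.old_forward_for_cp(*args, **kwargs)

        if self.layer.training:
            # Training: recompute activations during backward to save memory.
            output = checkpoint(func, *args)
        else:
            # Evaluation (model.eval()): plain forward, so checkpoint() never warns.
            output = func(*args)
        return output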
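The proposed patch can also be expressed as a module wrapper instead of monkeypatching forward. The sketch below is a hypothetical alternative: TrainOnlyCheckpoint is not part of alpaca_lora_4bit, and the extra torch.is_grad_enabled() check goes beyond the proposed fix (it also skips checkpointing inside a no_grad block while the model is still in train mode). It passes use_reentrant=True explicitly so newer PyTorch releases do not warn about its default.

```python
import warnings

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class TrainOnlyCheckpoint(nn.Module):
    """Hypothetical wrapper: checkpoint the inner layer only while training."""

    def __init__(self, layer: nn.Module):
        super().__init__()
        self.layer = layer

    def forward(self, *args, **kwargs):
        def run(*inner_args):
            # kwargs are captured here because reentrant checkpoint rejects extra kwargs.
            return self.layer(*inner_args, **kwargs)

        # Same condition as the proposed patch, plus an autograd-enabled check.
        if self.training and torch.is_grad_enabled():
            return checkpoint(run, *args, use_reentrant=True)
        return run(*args)


wrapped = TrainOnlyCheckpoint(nn.Linear(4, 4))

# Training: checkpointing is active and gradients still reach the parameters.
wrapped.train()
x = torch.randn(2, 4, requires_grad=True)
wrapped(x).sum().backward()
train_grad_ok = wrapped.layer.weight.grad is not None

# Evaluation as the Trainer runs it: eval mode + no_grad, so checkpoint() is skipped.
wrapped.eval()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    with torch.no_grad():
        eval_out = wrapped(torch.randn(2, 4))
eval_warnings = [str(w.message) for w in caught]
```

Because .train() and .eval() propagate to child modules, the wrapper's self.training always matches the mode the Trainer sets on the top-level model, which is the same property the proposed patch relies on through self.layer.training.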
alex4321 commented 1 year ago

@johnsmith0031, what do you think about this way of fixing these warnings? I'm not sure it's really ideal, because:

johnsmith0031 commented 1 year ago

Thanks, I think we can try this fix first.