huggingface / optimum-habana

Easy and lightning fast training of 🤗 Transformers on Habana Gaudi processor (HPU)
Apache License 2.0

LlamaForCausalLM.forward() got an unexpected keyword argument 'use_flash_attention' #760

Open dittops opened 7 months ago

dittops commented 7 months ago

System Info

optimum-habana==1.10.4
docker: vault.habana.ai/gaudi-docker/1.14.0/ubuntu22.04/habanalabs/pytorch-installer-2.1.1:latest

Reproduction

I'm trying to enable flash attention by setting model.generation_config.use_flash_attention = True, but I'm getting the error below.

    loss = self.module(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1521, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1530, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: LlamaForCausalLM.forward() got an unexpected keyword argument 'use_flash_attention'

Expected behavior

Flash attention should be enabled and training should run without any error.

regisss commented 7 months ago

Please share your script and the command you use to run it here; that makes the investigation much easier and you'll get a solution much faster. My best guess without this information is that you don't call the adapt_transformers_to_gaudi method in your script, or that you call it after importing some classes from Transformers.
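
For reference, here is a minimal sketch of the ordering described above (the checkpoint name is just a placeholder, and the import path is the one used in the optimum-habana sources, so it may differ across versions):

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

# Patch Transformers with the Gaudi-optimized model classes first, so that
# LlamaForCausalLM.forward() accepts the use_flash_attention keyword argument.
adapt_transformers_to_gaudi()

from transformers import AutoModelForCausalLM  # import Transformers classes only after patching

model = AutoModelForCausalLM.from_pretrained("your-llama-checkpoint")  # placeholder checkpoint name
model.generation_config.use_flash_attention = True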

dittops commented 7 months ago

I'm not calling that function in my script. I was following this example to enable flash attention: https://github.com/huggingface/optimum-habana/blob/main/examples/language-modeling/run_lora_clm.py

Here is my training script:

import pickle
import os
from dataclasses import dataclass, field
from typing import Optional
from itertools import chain

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, HfArgumentParser, DataCollatorForSeq2Seq     
from datasets import load_dataset
from optimum.habana import GaudiTrainer, GaudiTrainingArguments

# from model.llama import convert_llama_model, convert_llama_with_temperature

IGNORE_INDEX = -100

@dataclass
class ModelArguments:
    base_model: Optional[str] = field(default="base-model")
    use_lambda: Optional[bool] = field(default=False)
    temperature: Optional[float] = field(default=1)
    use_flash_attention: Optional[bool] = field(default=False)
    flash_attention_recompute: Optional[bool] = field(default=False)

@dataclass
class DataArguments:
    data_path: str = field(
      default=None, metadata={"help": "Path to the training data."}
    )
    max_seq_length = 2048
    is_tokenized: Optional[bool] = field(default=False)
    system_col: Optional[str] = field(default="")
    input_col: Optional[str] = field(default="input")
    target_col: Optional[str] = field(default="target")

@dataclass
class TrainArguments(GaudiTrainingArguments):
    per_device_train_batch_size = 2
    gradient_accumulation_steps = 1
    num_train_epochs = 3
    learning_rate = 2e-5
    fp16 = True
    logging_steps = 10
    optim = "adamw_torch"
    save_strategy = "epoch"
    output_dir = 'bud-1b'
    save_total_limit = 5
    report_to = 'wandb'
    adam_beta1 = 0.9
    adam_beta2 = 0.95
    stage: Optional[str] = field(default="pretrain")

def get_prompt(input, system = '', target = ''):
    # if not system:
    #     system = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."

    prompt_template = f"""{system}

    ### Instruction:
    {input}

    ### Response:
    {target}"""

    return prompt_template

def load_data(tokenizer, dataset, max_length, stage, system_col, input_col, target_col):

    def preprocess_pretrain_dataset(examples):

        text_ids = tokenizer(
            examples["text"],
            add_special_tokens=False)["input_ids"]

        concatenated_ids = list(chain(*text_ids))
        total_length = len(concatenated_ids)

        block_size = max_length
        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
        total_length = (total_length // block_size) * block_size
        # split by chunks of max_source_length
        result = [concatenated_ids[i: i + block_size]
                  for i in range(0, total_length, block_size)]

        return {
            "input_ids": result,
            "labels": result.copy()
        }

    def preprocess_supervised_dataset(examples):
        #data format = {"prompt":"", "response": "", "history": []}

        model_inputs = {"input_ids": [], "labels": [], "attention_mask": []}

        for i in range(len(examples[input_col])):
            input_ids, labels = [], []
            if system_col:
                source = get_prompt(examples[input_col][i], examples[system_col][i])
            else:
                source = get_prompt(examples[input_col][i])
            source_ids = tokenizer.encode(text=source, add_special_tokens=False)
            target_ids = tokenizer.encode(text=examples[target_col][i], add_special_tokens=False)
            input_ids += [tokenizer.bos_token_id] + source_ids + target_ids + [tokenizer.eos_token_id]
            labels += [IGNORE_INDEX] * (len(source_ids) + 1) + target_ids + [tokenizer.eos_token_id]

            if len(input_ids) > max_length:
                input_ids = input_ids[:max_length]
                labels = labels[:max_length]

            model_inputs["input_ids"].append(input_ids)
            model_inputs["labels"].append(labels)
            model_inputs["attention_mask"].append([1] * len(input_ids))

        return model_inputs

    def print_supervised_dataset_example(example):
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
        print("label_ids:\n{}".format(example["labels"]))
        print("labels:\n{}".format(tokenizer.decode([
            token_id if token_id != -100 else tokenizer.pad_token_id for token_id in example["labels"]
        ], skip_special_tokens=False)))

    if stage == 'pretrain':
        map_func = preprocess_pretrain_dataset
    else:
        map_func = preprocess_supervised_dataset

    column_names = dataset.column_names
    dataset = dataset.map(
        map_func,
        batched=True,
        remove_columns=column_names,
        num_proc=64
    )

    # print_supervised_dataset_example(dataset[0])
    print(len(dataset))

    return {
        "train_dataset": dataset
    }

def train():

    os.environ["WANDB_PROJECT"] = 'gaudi-pretrain-exp'

    # convert_llama_with_temperature()

    parser = HfArgumentParser(
        (ModelArguments, DataArguments, TrainArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    tokenizer = AutoTokenizer.from_pretrained(model_args.base_model, trust_remote_code=True)
    tokenizer.pad_token_id = 0

    model = AutoModelForCausalLM.from_pretrained(
        model_args.base_model,
        # device_map="auto",
        torch_dtype=torch.float16,
        # use_flash_attention_2=True,
        # attn_implementation="flash_attention_2",
        trust_remote_code=True
    )

    if model_args.use_flash_attention:
        model.generation_config.use_flash_attention = True
        model.generation_config.flash_attention_recompute = model_args.flash_attention_recompute

    # if model_args.use_lambda:
    #     local_branch = data_args.max_seq_length
    #     global_branch = 100
    #     model = convert_llama_model(model, local_branch, global_branch)
    #     print('Added lambda attention')

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("total params: ", total_params)

    if data_args.data_path.endswith(".json") or data_args.data_path.endswith(".jsonl"):
        data = load_dataset("json", data_files=data_args.data_path)
    elif data_args.data_path.endswith(".pkl"):
        with open(data_args.data_path, "rb") as file:
            data = pickle.load(file)
    else:
        data = load_dataset(data_args.data_path)

    if data_args.is_tokenized:
        dataset = {
            "train_dataset": data["train"]
        }
    else:
        with training_args.main_process_first(desc="dataset map tokenization"):
            dataset =  load_data(
                tokenizer, 
                data['train'], 
                data_args.max_seq_length, 
                training_args.stage, 
                data_args.system_col,
                data_args.input_col,
                data_args.target_col
            )

    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

    trainer = GaudiTrainer(
        model=model, 
        tokenizer=tokenizer, 
        args=training_args, 
        data_collator=data_collator,
        **dataset
    )
    # if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
    #     trainer.train(resume_from_checkpoint=True)
    # else:
    #     trainer.train()
    model.config.use_cache = False
    trainer.train(resume_from_checkpoint=False)

    trainer.save_model()

if __name__ == "__main__":
    train()
regisss commented 7 months ago

Thanks. What is the command you use to run this script?

dittops commented 7 months ago

Here is the command:

deepspeed train-gaudi.py --base_model budecosystem/boomer-1b --output_dir output/boomer --data_path roneneldan/TinyStories --learning_rate 1e-3 --num_train_epochs 1 --per_device_train_batch_size 2 --gradient_accumulation_steps 1 --lr_scheduler_type cosine --warmup_ratio 0.1 --report_to wandb --logging_steps 10 --save_strategy steps --save_steps 10000 --save_total_limit 2 --use_habana --gaudi_config_name gaudi_config.json --deepspeed ds_config.json --use_flash_attention True
regisss commented 7 months ago

I cannot reproduce it; it works on my side. Can you provide the full logs of your run and the output of pip list, please?

dittops commented 7 months ago

Thanks for your support. For some reason, I'm able to run the script without any issues now.

I have another question: does this flash attention have the same effect as the official implementation? I was able to run full-parameter training of a 634M-parameter Llama model with a per-device batch size of 10 on an NVIDIA A100 80 GB, but here I can only go up to a per-device batch size of 8.

regisss commented 7 months ago

Are you using Gaudi1 or Gaudi2?

dittops commented 7 months ago

I'm using Gaudi2

regisss commented 7 months ago

Can you share the logs of your run please? For a 634M-parameter model, you should be able to fit much bigger batches on Gaudi2.
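
If it helps when collecting those logs, here is a rough sketch of a helper for printing HPU memory usage. It assumes the habana_frameworks PyTorch bridge exposes CUDA-like memory-stats helpers (memory_allocated / max_memory_allocated); the exact API may vary across SynapseAI releases, and the helper name below is only illustrative:

import habana_frameworks.torch.hpu as hthpu

# Assumption: the HPU bridge mirrors the torch.cuda memory-stats API
# (memory_allocated / max_memory_allocated); adjust to your SynapseAI version.
def log_hpu_memory(tag: str) -> None:
    allocated_gib = hthpu.memory_allocated() / 2**30
    peak_gib = hthpu.max_memory_allocated() / 2**30
    print(f"[{tag}] HPU memory allocated: {allocated_gib:.2f} GiB (peak: {peak_gib:.2f} GiB)")

# For example, call log_hpu_memory("after train") right after trainer.train() in the script above.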