Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

RMS Norm TypeError: the first argument must be callable when training Llama from scratch #302

Open esnvidia opened 1 year ago

esnvidia commented 1 year ago

I get the TypeError below when training Llama from scratch.

model = GPTLMHeadModel(config=config, device=device, dtype=dtype)
  File "/usr/local/lib/python3.10/dist-packages/flash_attn/models/gpt.py", line 407, in __init__
    self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/flash_attn/models/gpt.py", line 310, in __init__
    self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group,
  File "/usr/local/lib/python3.10/dist-packages/flash_attn/models/gpt.py", line 310, in <listcomp>
    self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group,
  File "/usr/local/lib/python3.10/dist-packages/flash_attn/models/gpt.py", line 171, in create_block
    norm_cls = partial(nn.LayerNorm if not use_rms_norm else RMSNorm,
TypeError: the first argument must be callable
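For reference, this message is exactly what functools.partial raises when its first argument is not callable, e.g. when the class it is handed turns out to be None. A minimal standalone sketch:

from functools import partial

# partial() validates its first argument at construction time and raises
# "the first argument must be callable" if it is not a callable object.
norm_cls = None
try:
    partial(norm_cls, eps=1e-6)
except TypeError as e:
    print(e)  # the first argument must be callable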

I followed the example in the Llama tests, so here's example code with a custom tokenizer to train Llama from scratch with HuggingFace.

import torch
from transformers import LlamaConfig
from flash_attn.models import llama
from flash_attn.models.gpt import GPTLMHeadModel
from transformers import TrainingArguments, Trainer
from transformers import DataCollatorForLanguageModeling
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

# `tok` is a custom tokenizer and `tds` holds the tokenized train/valid splits (defined elsewhere).

training_args = TrainingArguments(
    output_dir="./llama",
    learning_rate=2e-5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=2,
    # bf16=True,
    # bf16_full_eval=True,
    bf16=False,
    bf16_full_eval=False,
)

data_collator = DataCollatorForLanguageModeling(tokenizer=tok, mlm=False)
llama_config = LlamaConfig(
    vocab_size=tok.vocab_size,
    hidden_size=1024,
    num_hidden_layers=2,
    num_attention_heads=32,
    bos_token_id=tok.bos_id,
    eos_token_id=tok.eos_id,
)
# below parts taken from here: https://github.com/HazyResearch/flash-attention/blob/72ad03eaa661f6bf3a14c855316c27fbab4f8f4c/tests/models/test_llama.py#L122
dtype = torch.float16
device = 'cuda'
config = llama.llama_config_to_gpt2_config(llama_config=llama_config)
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = False  # We don't have fused GatedMLP yet
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True

model = GPTLMHeadModel(config=config, device=device, dtype=dtype)

# continuing w/ HF Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tds["train"],
    eval_dataset=tds["valid"],
    tokenizer=tok,
    data_collator=data_collator,
)  

trainer.train()

The error occurs in model = GPTLMHeadModel(config=config, device=device, dtype=dtype). Later on, I'd like to set bf16=True and bf16_full_eval=True in the TrainingArguments, but I also noticed that the config defaults leave bf16 set to False even when the TrainingArguments have bf16=True and bf16_full_eval=True (not shown below).

Using transformers==4.30.2 installed in the NGC PyTorch 23.06 container. Here's an example Dockerfile to replicate my env.

FROM nvcr.io/nvidia/pytorch:23.06-py3

RUN pip install flash-attn transformers[torch] datasets

Here's the output of llama_config:

llama_config
LlamaConfig {
  "bos_token_id": 123456789,
  "eos_token_id": 123456789,
  "hidden_act": "silu",
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 2048,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 2,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-06,
  "tie_word_embeddings": false,
  "transformers_version": "4.30.2",
  "use_cache": true,
  "vocab_size": 123456790
}

and config

config
GPT2Config {
  "activation_function": "swiglu",
  "attn_pdrop": 0.0,
  "bos_token_id": 123456789,
  "embd_pdrop": 0.0,
  "eos_token_id": 123456789,
  "fused_bias_fc": true,
  "fused_dropout_add_ln": true,
  "fused_mlp": false,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-06,
  "mlp_fc1_bias": false,
  "mlp_fc2_bias": false,
  "model_type": "gpt2",
  "n_embd": 1024,
  "n_head": 32,
  "n_inner": 11008,
  "n_layer": 2,
  "n_positions": 0,
  "out_proj_bias": false,
  "pad_token_id": 0,
  "qkv_proj_bias": false,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.0,
  "residual_in_fp32": true,
  "rms_norm": true,
  "rotary_emb_fraction": 1.0,
  "rotary_emb_interleaved": true,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "tie_word_embeddings": false,
  "transformers_version": "4.30.2",
  "use_cache": true,
  "use_flash_attn": true,
  "vocab_size": 123456790
}
tridao commented 1 year ago

Did you install the dropout_layer_norm extension in this repo (cd csrc/layer_norm && pip install .)? If not then you should set config.fused_dropout_add_ln = False and config.residual_in_fp32 = False.
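One quick way to check whether that extension is actually present in an environment (a sketch; dropout_layer_norm is assumed to be the module name built by csrc/layer_norm, so adjust if your version differs):

# Sketch: probe for the compiled extension that csrc/layer_norm installs.
try:
    import dropout_layer_norm  # noqa: F401  (assumed module name)
    print("dropout_layer_norm extension is installed")
except ImportError:
    print("dropout_layer_norm missing; build it with: cd csrc/layer_norm && pip install .")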

esnvidia commented 1 year ago

I did not; I figured it was installed via pip install flash-attn.

I re-ran with:

config.fused_dropout_add_ln = False #True
config.residual_in_fp32 = False #True

and get the same error:

the first argument must be callable
  File "/workspace/tformer/llama.py", line 157, in <module>
    model = GPTLMHeadModel(config=config, device=device, dtype=dtype)
TypeError: the first argument must be callable
tridao commented 1 year ago

What's the line that gives that error? Can you put a breakpoint and check what's being called there?

tridao commented 1 year ago

Oh wait, if you didn't install that extension then RMSNorm = None, so I don't think it would work.
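The fallback in flash_attn/models/gpt.py follows roughly this pattern (a sketch, not the exact source): if importing the extension-backed RMSNorm fails, the name is silently left as None, which partial() then rejects with the error above.

# Rough sketch of the import fallback; without the compiled extension the
# import fails and RMSNorm stays None instead of being a usable class.
try:
    from flash_attn.ops.rms_norm import RMSNorm
except ImportError:
    RMSNorm = None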

esnvidia commented 1 year ago

Yes it's here:

flash_attn/models/gpt.py", line 171, in create_block
    norm_cls = partial(nn.LayerNorm if not use_rms_norm else RMSNorm, ...
TypeError: the first argument must be callable
tridao commented 1 year ago

For Llama you need to install that extension to use RMSNorm.

esnvidia commented 1 year ago

Ok, so what are the appropriate install cmds to get all the extensions?

It would be good to have clarity on what pip install flash-attn includes, or maybe offer something like pip install flash-attn[full] to install everything.

The README seems to suggest an either-or: either install via pip, or (alternatively) install from source.

tridao commented 1 year ago

You can see https://github.com/HazyResearch/flash-attention/tree/main/training for the installation commands. I don't know how to make flash-attn[full] work; if you have pointers I'd appreciate that.
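For reference, a sketch of what that amounts to on top of the Dockerfile above, assuming the extension directories present under csrc/ at the time (layer_norm, fused_dense_lib, rotary, xentropy); directory names and layout may have changed since:

FROM nvcr.io/nvidia/pytorch:23.06-py3

RUN pip install flash-attn transformers[torch] datasets

# Build the optional CUDA extensions from source; each subdirectory ships
# its own setup.py, so a plain pip install of the directory is enough.
RUN git clone https://github.com/HazyResearch/flash-attention /tmp/flash-attention \
    && pip install /tmp/flash-attention/csrc/layer_norm \
    && pip install /tmp/flash-attention/csrc/fused_dense_lib \
    && pip install /tmp/flash-attention/csrc/rotary \
    && pip install /tmp/flash-attention/csrc/xentropy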