esnvidia opened this issue 1 year ago
Did you install the dropout_layer_norm extension in this repo (cd csrc/layer_norm && pip install .)? If not, then you should set config.fused_dropout_add_ln = False and config.residual_in_fp32 = False.
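A quick way to check which case applies (a sketch; dropout_layer_norm is assumed to be the importable module name the extension builds):

```python
# Sketch: check whether the fused layer-norm extension from csrc/layer_norm is importable.
# The module name "dropout_layer_norm" is assumed from the extension name mentioned above.
try:
    import dropout_layer_norm  # built by `cd csrc/layer_norm && pip install .`
    print("fused layer norm available; config.fused_dropout_add_ln = True should work")
except ImportError:
    print("extension missing; set config.fused_dropout_add_ln = False and "
          "config.residual_in_fp32 = False, or build csrc/layer_norm")
```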
I did not; I figured it was installed via pip install flash-attn.
I re-ran with:
config.fused_dropout_add_ln = False #True
config.residual_in_fp32 = False #True
and got the same error:
the first argument must be callable
File "/workspace/tformer/llama.py", line 157, in <module>
model = GPTLMHeadModel(config=config, device=device, dtype=dtype)
TypeError: the first argument must be callable
What's the line that gives that error? Can you put a breakpoint and check what's being called there?
Oh wait, if you didn't install that extension then RMSNorm = None, so it wouldn't work I don't think.
Yes it's here:
flash_attn/models/gpt.py", line 171, in create_block
norm_cls = partial(nn.LayerNorm if not use_rms_norm else RMSNorm, ...
TypeError: the first argument must be callable
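For reference, this is just functools.partial rejecting a non-callable first argument; a standalone repro:

```python
# Standalone repro: when the extension is missing, the guarded import leaves RMSNorm = None,
# and functools.partial(None, ...) raises exactly this TypeError.
from functools import partial

RMSNorm = None  # what flash_attn falls back to without the csrc/layer_norm extension
try:
    norm_cls = partial(RMSNorm, eps=1e-5)
except TypeError as e:
    print(e)  # -> the first argument must be callable
```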
For Llama you need to install that extension to use RMSNorm.
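After building csrc/layer_norm, a quick sanity check (the import path here is an assumption about where flash_attn picks up RMSNorm; adjust as needed):

```python
# Sanity check after building csrc/layer_norm: if this import succeeds, create_block
# gets a real RMSNorm class instead of None and the TypeError should go away.
# The import path is an assumption; adjust it to wherever RMSNorm lives in your install.
from flash_attn.ops.rms_norm import RMSNorm

print(RMSNorm)  # a class, not None
```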
Ok, so what are the appropriate install cmds to get all the extensions?
Would be good to have clarity on what pip install flash-attn includes. Or maybe add something like pip install flash-attn[full] to just install it all.
The README seems to suggest an either-or: either install via pip or (alternatively) install from source.
You can see https://github.com/HazyResearch/flash-attention/tree/main/training for the installation cmd.
Idk how to make flash-attn[full] work; if you have pointers I'd appreciate that.
I get the TypeError below when training Llama from scratch.
I followed the example in the tests for llama, so here's example code with a custom tokenizer to train Llama from scratch w/ HuggingFace.
The error occurs in
model = GPTLMHeadModel(config=config, device=device, dtype=dtype)
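Roughly, the setup follows the llama test; a minimal sketch (not the full script; llama_config_to_gpt2_config and the flag names are taken from the repo's tests and may need adjusting):

```python
# Minimal sketch (not the original script): build a Llama-style model from a HuggingFace
# LlamaConfig, following the pattern in this repo's llama tests. llama_config_to_gpt2_config
# and the config flags below are assumptions taken from those tests.
import torch
from transformers import LlamaConfig
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.llama import llama_config_to_gpt2_config

llama_config = LlamaConfig(hidden_size=1024, num_hidden_layers=8, num_attention_heads=8)
config = llama_config_to_gpt2_config(llama_config)
config.use_flash_attn = True
config.fused_dropout_add_ln = False  # True requires the csrc/layer_norm extension
config.residual_in_fp32 = False
# Note: Llama uses RMSNorm, which also needs the csrc/layer_norm extension installed.
model = GPTLMHeadModel(config=config, device="cuda", dtype=torch.bfloat16)
```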
Later on, I'd like to set bf16=True, bf16_full_eval=True in the TrainingArguments, but I also noticed that the defaults from the config set bf16 to False even when the TrainingArguments have bf16=True and bf16_full_eval=True (not shown below).

Using transformers 4.30.2 installed in the NGC PyTorch 23.06 container. Here's an example Dockerfile to replicate my env.
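For reference, the TrainingArguments I have in mind (a sketch; the dataset/Trainer wiring is omitted):

```python
# Sketch of the intended TrainingArguments (transformers 4.30.2); bf16 and
# bf16_full_eval are standard fields, and output_dir here is just a placeholder.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./llama-from-scratch",
    bf16=True,
    bf16_full_eval=True,
)
```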
Here's the output of llama_config and config: