HazyResearch / based

Code for exploring Based models from "Simple linear attention language models balance the recall-throughput tradeoff"
Apache License 2.0

Type Error in GPTLMHeadModel #3

Open axelmagn opened 6 months ago

axelmagn commented 6 months ago

I am having a go at running inference and evaluation for this model, and I am running into a TypeError in GPTLMHeadModel:

In [1]: import torch
   ...: from transformers import AutoTokenizer
   ...: from based.models.gpt import GPTLMHeadModel
   ...: 
   ...: tokenizer = AutoTokenizer.from_pretrained("gpt2")
   ...: model = GPTLMHeadModel.from_pretrained_hf("hazyresearch/based-360m").to("cuda", dtype=torch.float16)
tokenizer_config.json: 100%|███████████████████████████████████████████| 26.0/26.0 [00:00<00:00, 260kB/s]
config.json: 100%|██████████████████████████████████████████████████████| 665/665 [00:00<00:00, 8.64MB/s]
vocab.json: 100%|███████████████████████████████████████████████████| 1.04M/1.04M [00:00<00:00, 12.1MB/s]
merges.txt: 100%|█████████████████████████████████████████████████████| 456k/456k [00:00<00:00, 8.99MB/s]
tokenizer.json: 100%|███████████████████████████████████████████████| 1.36M/1.36M [00:00<00:00, 17.8MB/s]
config.json: 100%|██████████████████████████████████████████████████| 2.86k/2.86k [00:00<00:00, 36.7MB/s]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 6
      3 from based.models.gpt import GPTLMHeadModel
      5 tokenizer = AutoTokenizer.from_pretrained("gpt2")
----> 6 model = GPTLMHeadModel.from_pretrained_hf("hazyresearch/based-360m").to("cuda", dtype=torch.float16)

File /based/models/gpt.py:468, in GPTPreTrainedModel.from_pretrained_hf(cls, pretrained_model_name, device, **kwargs)
    466 config_data = load_config_hf(pretrained_model_name)
    467 config = GPT2Config(**config_data)
--> 468 model = cls(config, device=device, **kwargs)
    469 state_dict = load_state_dict_hf(pretrained_model_name, device=device)
    471 # remove the 'model.' prefix from the keys

File /based/models/gpt.py:741, in GPTLMHeadModel.__init__(self, config, process_group, device, dtype)
    739 super().__init__(config)
    740 self.process_group = process_group
--> 741 self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
    742 self.tie_word_embeddings = getattr(config, "tie_word_embeddings", True)
    743 lm_head_bias = getattr(config, "lm_head_bias", False)

File /based/models/gpt.py:585, in GPTModel.__init__(self, config, process_group, device, dtype)
    569     self.embeddings = ParallelGPT2Embeddings(
    570         config.hidden_size,
    571         vocab_size,
   (...)
    575         **factory_kwargs,
    576     )
    578 # We change the order of dropout, residual and layer norm:
    579 # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
    580 # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
    581 # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
    582 # nn.Dropout probabilities are changed.
    583 # This is for performance reason: we can fuse dropout + add + layer_norm.
    584 self.layers = nn.ModuleList(
--> 585     [
    586         create_block(config, layer_idx=i, process_group=process_group, **factory_kwargs)
    587         for i in range(config.num_hidden_layers)
    588     ]
    589 )
    590 self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
    591 if self.fused_dropout_add_ln:

File /based/models/gpt.py:586, in <listcomp>(.0)
    569     self.embeddings = ParallelGPT2Embeddings(
    570         config.hidden_size,
    571         vocab_size,
   (...)
    575         **factory_kwargs,
    576     )
    578 # We change the order of dropout, residual and layer norm:
    579 # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
    580 # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
    581 # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
    582 # nn.Dropout probabilities are changed.
    583 # This is for performance reason: we can fuse dropout + add + layer_norm.
    584 self.layers = nn.ModuleList(
    585     [
--> 586         create_block(config, layer_idx=i, process_group=process_group, **factory_kwargs)
    587         for i in range(config.num_hidden_layers)
    588     ]
    589 )
    590 self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
    591 if self.fused_dropout_add_ln:

File /based/models/gpt.py:371, in create_block(config, layer_idx, process_group, device, dtype, **kwargs)
    369 mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    370 use_rms_norm = getattr(config, "rms_norm", False)
--> 371 norm_cls = partial(
    372     nn.LayerNorm if not use_rms_norm else RMSNorm,
    373     eps=config.layer_norm_epsilon,
    374     **factory_kwargs,
    375 )
    376 # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
    377 residual_in_fp32 = getattr(config, "residual_in_fp32", False)

TypeError: the first argument must be callable

For reproducibility, I have been running this in a Docker container:

FROM nvidia/cuda:11.8.0-devel-ubuntu22.04

RUN apt-get update && apt-get install -y \
    apt-utils \
    python3.10 \
    python3-pip \
    git \
    && rm -rf /var/lib/apt/lists/*

RUN pip install --upgrade pip
RUN pip install \
    torch==2.1.2 \
    torchvision==0.16.2 \
    torchaudio==2.1.2 \
    --index-url https://download.pytorch.org/whl/cu118 # due to observed causal-conv1d dependency

RUN pip install \
    jupyter==1.0.0 \
    hydra-core==1.3.2

RUN pip install jupyter
COPY . .
RUN pip install .

Any idea what could be going wrong here?

simran-arora commented 6 months ago

Hi, I think it's because RMSNorm is being set to None here: https://github.com/HazyResearch/based/blob/e8de5648f7e84248be8ebc1499e817641b0f577b/based/models/gpt.py#L371

This happens because of the import structure here: https://github.com/HazyResearch/based/blob/e8de5648f7e84248be8ebc1499e817641b0f577b/based/models/gpt.py#L52

The options are to

Sorry for the difficulty; we will fix the install / instructions for this.
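
The failure Simran describes can be reproduced without the model: when an optional import is wrapped in try/except and falls back to None, functools.partial raises exactly the reported message as soon as the config selects RMSNorm. The sketch below is a minimal, torch-free illustration under stated assumptions: the module name hypothetical_norm_ext and the LayerNorm stand-in are made up, and make_norm_cls is only a suggestion of how the check could fail earlier with a clearer message, not based's actual code.

```python
from functools import partial

# Guarded optional import, the kind of import structure Simran points to for RMSNorm.
# "hypothetical_norm_ext" is a made-up module name used only for illustration.
try:
    from hypothetical_norm_ext import RMSNorm
except ImportError:
    RMSNorm = None  # the name silently becomes None when the extension is missing


class LayerNorm:
    """Minimal stand-in for torch.nn.LayerNorm so the sketch runs without torch."""

    def __init__(self, dim, eps=1e-5):
        self.dim = dim
        self.eps = eps


def partial_error_message():
    """Reproduce the reported failure: partial() rejects None as its first argument."""
    try:
        partial(None, eps=1e-5)
    except TypeError as exc:
        return str(exc)
    return None


def make_norm_cls(use_rms_norm, eps=1e-5):
    """Pick the norm class, failing with a clear message instead of passing None to partial()."""
    norm = RMSNorm if use_rms_norm else LayerNorm
    if norm is None:
        raise ImportError("rms_norm=True in the config, but RMSNorm could not be imported")
    return partial(norm, eps=eps)


print(partial_error_message())  # prints: the first argument must be callable
```

Checking for None before building the partial turns a confusing TypeError into an ImportError that names the missing dependency.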

axelmagn commented 6 months ago

No worries, and thanks for the speedy reply. Your guidance helped me get past the above error by installing the norm from flash-attn, but there seem to be more undocumented dependency issues:

root@d75213223120:/app# python3 test_script.py 
tokenizer_config.json: 100%|██████████████████████████████████████████████| 26.0/26.0 [00:00<00:00, 149kB/s]
config.json: 100%|█████████████████████████████████████████████████████████| 665/665 [00:00<00:00, 8.40MB/s]
vocab.json: 100%|██████████████████████████████████████████████████████| 1.04M/1.04M [00:00<00:00, 9.97MB/s]
merges.txt: 100%|████████████████████████████████████████████████████████| 456k/456k [00:00<00:00, 28.5MB/s]
tokenizer.json: 100%|██████████████████████████████████████████████████| 1.36M/1.36M [00:00<00:00, 9.81MB/s]
config.json: 100%|█████████████████████████████████████████████████████| 2.86k/2.86k [00:00<00:00, 35.0MB/s]
No module named 'causal_attention_cuda'
Traceback (most recent call last):
  File "/app/test_script.py", line 6, in <module>
    model = GPTLMHeadModel.from_pretrained_hf("hazyresearch/based-360m").to("cuda", dtype=torch.float16)
  File "/app/based/models/gpt.py", line 468, in from_pretrained_hf
    model = cls(config, device=device, **kwargs)
  File "/app/based/models/gpt.py", line 741, in __init__
    self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
  File "/app/based/models/gpt.py", line 585, in __init__
    [
  File "/app/based/models/gpt.py", line 586, in <listcomp>
    create_block(config, layer_idx=i, process_group=process_group, **factory_kwargs)
  File "/app/based/models/gpt.py", line 382, in create_block
    block = Block(
  File "/app/based/models/block.py", line 86, in __init__
    self.mixer = mixer_cls(dim)
  File "/app/based/models/mixers/slide_attention.py", line 357, in __init__
    if fused_bias_fc and FusedDense is None: raise ImportError("fused_dense is not installed")
ImportError: fused_dense is not installed

I'm a little baffled, since it seems like FusedDense is being imported from flash_attn here:

https://github.com/HazyResearch/based/blob/e8de5648f7e84248be8ebc1499e817641b0f577b/based/models/mixers/slide_attention.py#L29

Are there additional subpackages within flash-attn that need to be installed?

For reference, here is my updated dockerfile:

FROM nvidia/cuda:11.8.0-devel-ubuntu22.04

RUN apt-get update && apt-get install -y \
    apt-utils \
    python3.10 \
    python3-pip \
    git \
    && rm -rf /var/lib/apt/lists/*

RUN pip install --upgrade pip
RUN pip install \
    torch==2.1.2 \
    torchvision==0.16.2 \
    torchaudio==2.1.2 \
    --index-url https://download.pytorch.org/whl/cu118 # due to observed causal-conv1d dependency

RUN pip install \
    jupyter==1.0.0 \
    hydra-core==1.3.2 \
    packaging==23.2 \
    ninja==1.11.1.1 

# RUN pip install 'git+https://github.com/Dao-AILab/flash-attention.git@6c9e60d' 
RUN pip install 'git+https://github.com/Dao-AILab/flash-attention.git@6c9e60d#subdirectory=csrc/layer_norm'

# install apex
RUN pip install -v \
    --disable-pip-version-check \
    --no-cache-dir \
    --no-build-isolation \
    --config-settings "--build-option=--cpp_ext" \
    --config-settings "--build-option=--cuda_ext" \
    'git+https://github.com/NVIDIA/apex@b496d85'

# install based
RUN mkdir -p /app
WORKDIR /app
COPY . .
RUN pip install .

CMD python3 test_script.py

simran-arora commented 6 months ago

That line you pointed out requires this to be installed: https://github.com/Dao-AILab/flash-attention/tree/main/csrc/fused_dense_lib
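
On the FusedDense puzzle above: as far as I understand the flash-attention layout, flash_attn/ops/fused_dense.py imports a separately compiled extension (fused_dense_lib) at import time, so installing the main flash-attn package is not enough, and the guarded import in based turns that failure into FusedDense = None. A small script like the one below can report which optional pieces are importable before you load the model. The module names are my assumptions about flash-attention v2.5-era builds (causal_attention_cuda is taken from the log above); adjust them to your installed version.

```python
import importlib


def check_optional_modules(names):
    """Return {module_name: None if importable, else 'ErrorType: message'}."""
    results = {}
    for name in names:
        try:
            importlib.import_module(name)
            results[name] = None
        except Exception as exc:  # broad on purpose: CUDA extensions can fail with OSError etc.
            results[name] = f"{type(exc).__name__}: {exc}"
    return results


if __name__ == "__main__":
    # Assumed module names; verify them against the flash-attention version you installed.
    candidates = [
        "flash_attn",
        "flash_attn.ops.fused_dense",  # Python wrapper that needs fused_dense_lib
        "fused_dense_lib",             # assumed to be built from csrc/fused_dense_lib
        "dropout_layer_norm",          # assumed to be built from csrc/layer_norm
        "causal_attention_cuda",       # name taken from the log in this thread
    ]
    for name, err in check_optional_modules(candidates).items():
        print(f"{name:30s} {'OK' if err is None else err}")
```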

I would recommend cloning flash-attention and running python setup.py install within this directory. An alternative workaround, without the install, is to set fused_bias_fc = False in the config.
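
Simran's second option, disabling fused_bias_fc in the config, could look roughly like the sketch below. I have not checked whether fused_bias_fc sits at the top level of the based-360m config.json or inside a mixer sub-config, so this hypothetical helper, disable_fused_bias_fc, flips every occurrence it finds in a parsed config dict and leaves the original untouched.

```python
import copy


def disable_fused_bias_fc(config_data):
    """Return a deep copy of a config dict with every 'fused_bias_fc' key set to False.

    Hypothetical helper: walks nested dicts and lists because the exact
    location of the key in the based config is not confirmed here.
    """
    patched = copy.deepcopy(config_data)

    def _walk(node):
        if isinstance(node, dict):
            for key, value in node.items():
                if key == "fused_bias_fc":
                    node[key] = False  # reassigning an existing key is safe during iteration
                else:
                    _walk(value)
        elif isinstance(node, list):
            for item in node:
                _walk(item)

    _walk(patched)
    return patched
```

The patched dict would then need to go through the same load_config_hf / GPT2Config(**config_data) steps that from_pretrained_hf performs in the first traceback, rather than letting from_pretrained_hf read the unpatched config.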

axelmagn commented 6 months ago

After some tweaking, I think I've got it working. I ended up using the HazyResearch/flash-attention fork. For others trying via Docker, this is the Dockerfile I used:

FROM nvidia/cuda:11.8.0-devel-ubuntu22.04

ARG TORCH_CUDA_ARCH_LIST="8.0+PTX"

RUN apt-get update && apt-get install -y \
    build-essential \
    apt-utils \
    python3.10 \
    python3-pip \
    git \
    && rm -rf /var/lib/apt/lists/*

RUN pip install --upgrade pip
RUN pip install \
    torch==2.1.2 \
    torchvision==0.16.2 \
    torchaudio==2.1.2 \
    --index-url https://download.pytorch.org/whl/cu118 # due to observed causal-conv1d dependency

RUN pip install \
    jupyter==1.0.0 \
    hydra-core==1.3.2 \
    packaging==23.2 \
    ninja==1.11.1.1 

# install apex
RUN pip install -v \
    --disable-pip-version-check \
    --no-cache-dir \
    --no-build-isolation \
    --config-settings "--build-option=--cpp_ext" \
    --config-settings "--build-option=--cuda_ext" \
    'git+https://github.com/NVIDIA/apex@b496d85'

RUN pip install 'git+https://github.com/HazyResearch/flash-attention@v2.5.2' --no-build-isolation
RUN pip install 'git+https://github.com/HazyResearch/flash-attention@v2.5.2#subdirectory=csrc/fused_dense_lib'  --no-build-isolation
RUN pip install 'git+https://github.com/HazyResearch/flash-attention@v2.5.2#subdirectory=csrc/layer_norm' --no-build-isolation

# install based
RUN mkdir -p /app
WORKDIR /app
COPY . .
RUN pip install .

CMD python3 test_script.py

It requires the NVIDIA Docker toolkit to run, with the command:

docker run --rm --runtime=nvidia --gpus all based
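
The Dockerfile above sets TORCH_CUDA_ARCH_LIST="8.0+PTX". Extensions built through PyTorch's C++/CUDA extension machinery generally read this variable (some setup.py files add their own -gencode flags instead), so those builds target compute capability 8.0 (A100-class GPUs) plus PTX. If you run the image on a different GPU, the sketch below is one way to reason about whether such a build can run there. It is an assumption-laden helper, not part of based: it only parses numeric entries such as "7.5;8.0+PTX" (PyTorch also accepts architecture names, which this ignores) and encodes CUDA's general compatibility rule: compiled SASS runs on GPUs with the same major version and an equal or newer minor version, and PTX can be JIT-compiled for any GPU at least as new as its target.

```python
def parse_arch_list(arch_list):
    """Parse a numeric TORCH_CUDA_ARCH_LIST string like "7.5;8.0+PTX" into (sass, ptx) sets."""
    sass, ptx = set(), set()
    for entry in arch_list.replace(";", " ").split():
        with_ptx = entry.endswith("+PTX")
        major, minor = entry.removesuffix("+PTX").split(".")
        cap = (int(major), int(minor))
        sass.add(cap)  # every entry produces SASS for that architecture
        if with_ptx:
            ptx.add(cap)  # "+PTX" additionally embeds PTX for JIT compilation
    return sass, ptx


def arch_list_covers(arch_list, device_cap):
    """Return True if a build for arch_list should run on a GPU with capability device_cap."""
    sass, ptx = parse_arch_list(arch_list)
    major, minor = tuple(device_cap)
    # SASS: same major version, device minor version equal or newer.
    if any(m == major and n <= minor for m, n in sass):
        return True
    # PTX: the driver can JIT-compile it for any device at least as new as the PTX target.
    return any((m, n) <= (major, minor) for m, n in ptx)
```

With torch installed, torch.cuda.get_device_capability() returns the (major, minor) tuple to pass as device_cap.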

melisa-writer commented 5 months ago

Hi! I got a similar problem while running the sample code:

import torch
from transformers import AutoTokenizer
from based.models.gpt import GPTLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = GPTLMHeadModel.from_pretrained_hf("hazyresearch/based-360m").to("cuda", dtype=torch.float16)

input = tokenizer.encode("If I take one more step, it will be", return_tensors="pt").to("cuda")
output = model.generate(input, max_length=20)
print(tokenizer.decode(output[0]))

Error:

Traceback (most recent call last):
  File "/home/melisarussak/based/inference_test.py", line 6, in <module>
    model = GPTLMHeadModel.from_pretrained_hf("hazyresearch/based-360m").to("cuda", dtype=torch.float16)
  File "/home/melisarussak/based/based/models/gpt.py", line 470, in from_pretrained_hf
    model = cls(config, device=device, **kwargs)
  File "/home/melisarussak/based/based/models/gpt.py", line 743, in __init__
    self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
  File "/home/melisarussak/based/based/models/gpt.py", line 587, in __init__
    [
  File "/home/melisarussak/based/based/models/gpt.py", line 588, in <listcomp>
    create_block(config, layer_idx=i, process_group=process_group, **factory_kwargs)
  File "/home/melisarussak/based/based/models/gpt.py", line 373, in create_block
    norm_cls = partial(
TypeError: the first argument must be callable

So I used the Dockerfile given by @axelmagn, and now I get:

No module named 'causal_attention_cuda'
Successfully imported the causal dot product kernel!
Could not import the FLA triton kernels...
Traceback (most recent call last):
  File "/app/inference_test.py", line 9, in <module>
    output = model.generate(input, max_length=20)
  File "/app/based/generation.py", line 573, in generate
    output = decode(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/app/based/generation.py", line 194, in decode
    scores.append(get_logits(sequences[-1], inference_params))
  File "/app/based/generation.py", line 155, in get_logits
    logits = model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/based/models/gpt.py", line 806, in forward
    hidden_states = self.transformer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/based/models/gpt.py", line 674, in forward
    hidden_states, residual = layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/based/models/block.py", line 189, in forward
    hidden_states = self.mixer(hidden_states, position_ids=position_ids, decay=decay, **mixer_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/based/models/mixers/linear_attention.py", line 127, in forward
    return self.recurrent_forward(hidden_states, kv_state, k_state, q, k, v)
  File "/app/based/models/mixers/linear_attention.py", line 195, in recurrent_forward
    kv_state += k[:, :, -1:] * v[:, :, -1:]
RuntimeError: The size of tensor a (16) must match the size of tensor b (273) at non-singleton dimension 4

Is this due to the code changes from 2 days ago, or am I missing some steps?
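
For context on the shape error: in causal linear attention the decoding state is a running sum of outer products phi(k_t) v_t^T, so it has shape (feature dim x value dim), and the normalizer is a running sum of phi(k_t). The 273 in the error happens to equal 1 + 16 + 16^2, the size of a second-order Taylor feature map over 16-dimensional inputs, so one plausible reading (my assumption, not confirmed in the thread) is that the key had already been feature-mapped while the value kept head dimension 16, and the two were multiplied elementwise instead of as an outer product. The pure-Python sketch below shows the single-head recurrence; recurrent_step, run, and the list-based layout are illustrative and are not based's actual linear_attention.py.

```python
def outer(a, b):
    """Outer product a b^T as a nested list (len(a) rows, len(b) columns)."""
    return [[x * y for y in b] for x in a]


def recurrent_step(kv_state, k_state, q, k, v):
    """One decoding step of causal linear attention for a single head.

    kv_state: feature_dim x value_dim list, running sum of phi(k_i) v_i^T
    k_state:  length-feature_dim list, running sum of phi(k_i)
    q, k:     already feature-mapped query/key (length feature_dim)
    v:        value vector (length value_dim)
    """
    if len(q) != len(k) or len(kv_state) != len(k):
        raise ValueError("q, k and kv_state must share the feature dimension")
    update = outer(k, v)  # (feature_dim, value_dim): an outer product, not an elementwise k * v
    kv_state = [[s + u for s, u in zip(s_row, u_row)] for s_row, u_row in zip(kv_state, update)]
    k_state = [s + x for s, x in zip(k_state, k)]
    numerator = [sum(q[i] * kv_state[i][j] for i in range(len(q))) for j in range(len(v))]
    denominator = sum(qi * zi for qi, zi in zip(q, k_state))
    return [n / denominator for n in numerator], kv_state, k_state


def run(steps, feature_dim, value_dim):
    """Feed (q, k, v) triples through recurrent_step from a zero state and collect outputs."""
    kv_state = [[0.0] * value_dim for _ in range(feature_dim)]
    k_state = [0.0] * feature_dim
    outputs = []
    for q, k, v in steps:
        y, kv_state, k_state = recurrent_step(kv_state, k_state, q, k, v)
        outputs.append(y)
    return outputs
```

Each output equals sum_i (q . k_i) v_i / sum_i (q . k_i) over the keys seen so far, which is what makes the recurrent form match the parallel one.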

simran-arora commented 5 months ago

Yes, that was due to the changes. Please try again and let me know if you run into issues.

melisa-writer commented 5 months ago

It works now! 🎉 Thank you!