mosaicml / llm-foundry

LLM training code for Databricks foundation models
https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm
Apache License 2.0

Getting OOM on 8 NVIDIA GPUs with 40GB memory each #54

Closed arpitg1991 closed 1 year ago

arpitg1991 commented 1 year ago

Using scripts/train/train.py with the YAML yamls/mpt/finetune/7b_dolly_sft.yaml.

abhi-mosaic commented 1 year ago

Hey there! Thanks for trying out our repo.

Could you share any diffs you made to the YAML, or is it exactly as in the repo? Tagging @alextrott16 as the finetuning expert. We will reproduce on our side.

For an immediate fix, could you try lowering device_train_microbatch_size from 8 to say, 4, and see if that helps?
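Lowering device_train_microbatch_size should not change the effective batch size. It only splits each per-device batch into smaller pieces whose gradients are accumulated before the optimizer step. The sketch below shows that arithmetic, assuming the global batch is divided evenly across GPUs; the helper name per_device_batching is hypothetical and is not a Composer API.

```python
import math


def per_device_batching(global_train_batch_size: int, n_gpus: int,
                        device_train_microbatch_size: int) -> tuple[int, int]:
    """Return (device_train_batch_size, grad_accum_steps).

    Hypothetical helper. Assumes the global batch is split evenly across GPUs
    and each per-device batch is processed in microbatches whose gradients
    are accumulated before a single optimizer step.
    """
    if global_train_batch_size % n_gpus != 0:
        raise ValueError("global batch size must be divisible by the number of GPUs")
    device_batch = global_train_batch_size // n_gpus
    grad_accum = math.ceil(device_batch / device_train_microbatch_size)
    return device_batch, grad_accum


# 64 global / 8 GPUs, as in the config logged later in this thread.
print(per_device_batching(64, 8, 8))  # (8, 1): matches the logged device_train_grad_accum: 1
print(per_device_batching(64, 8, 4))  # (8, 2): same effective batch, smaller microbatches
```

Only the activation memory held at any one time shrinks; the weights and optimizer state per GPU stay the same.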

alextrott16 commented 1 year ago

Hi!

Memory usage should be the same with finetuning as with pre-training. It basically just depends on the model size, device_train_microbatch_size, max_seq_len and model.attn_config.attn_impl. With FSDP sharding, it will also depend on the number of devices. 8 GPUs should be enough with 80 GB of GPU RAM each, but might not be enough with 40 GB.
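To see why the number of devices matters, here is a rough back-of-envelope estimate of per-GPU memory for model state alone. It assumes FSDP FULL_SHARD splits weights, gradients and AdamW moments evenly across devices, and that each parameter costs about 16 bytes (fp32 weights, fp32 gradients, two fp32 optimizer moments). Those byte counts are an assumption, not a measurement; activation memory (driven by device_train_microbatch_size, max_seq_len and attn_impl) comes on top. The parameter count is the n_params value logged in this thread.

```python
GIB = 1024 ** 3
N_PARAMS = 6_658_859_008  # n_params as logged for the MPT-7B config in this thread


def sharded_state_gib(n_params: float, n_devices: int, bytes_per_param: int = 16) -> float:
    """Rough per-GPU memory (GiB) for weights + grads + AdamW state.

    Assumes 16 bytes/param (4 weight + 4 grad + 8 for two AdamW moments, all
    fp32) split evenly across n_devices. Excludes activations, temporary
    buffers and the CUDA context.
    """
    return n_params * bytes_per_param / n_devices / GIB


print(f"{sharded_state_gib(N_PARAMS, 1):.1f} GiB")  # 99.2 GiB: far above the 39.56 GiB in the OOM traceback
print(f"{sharded_state_gib(N_PARAMS, 8):.1f} GiB")  # 12.4 GiB per GPU, before activations
```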

In any case, lowering device_train_microbatch_size to 4 is your best bet :)

arpitg1991 commented 1 year ago

I am using the exact same config. It's a little unclear how to use the Hugging Face model; I am adding this at the end of the YAML: load_path: /root/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b/d8304854d4877849c3c0a78f3469512a84419e84/

I get the CUDA error in the data loading stage itself. When I reduce seq_len to a low value like 100 it goes away, but then it happens during the model loading stage. nvidia-smi gives this:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  On   | 00000000:10:1C.0 Off |                    0 |
| N/A   26C    P0    51W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM...  On   | 00000000:10:1D.0 Off |                    0 |
| N/A   26C    P0    51W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA A100-SXM...  On   | 00000000:20:1C.0 Off |                    0 |
| N/A   27C    P0    49W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   3  NVIDIA A100-SXM...  On   | 00000000:20:1D.0 Off |                    0 |
| N/A   26C    P0    52W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   4  NVIDIA A100-SXM...  On   | 00000000:90:1C.0 Off |                    0 |
| N/A   26C    P0    50W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   5  NVIDIA A100-SXM...  On   | 00000000:90:1D.0 Off |                    0 |
| N/A   27C    P0    55W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   6  NVIDIA A100-SXM...  On   | 00000000:A0:1C.0 Off |                    0 |
| N/A   26C    P0    50W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   7  NVIDIA A100-SXM...  On   | 00000000:A0:1D.0 Off |                    0 |
| N/A   25C    P0    49W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
arpitg1991 commented 1 year ago

Even doing this doesn't help:

global_train_batch_size: 8 # assuming 8 gpus

# System
seed: ${global_seed}
device_eval_batch_size: 1
device_train_microbatch_size: 1
# device_train_microbatch_size: auto
precision: amp_bf16
arpitg1991 commented 1 year ago

@vchiley

alextrott16 commented 1 year ago

I'll admit to being a little puzzled by this. I was just able to run cd llm-foundry/scripts/train && composer train.py yamls/mpt/finetune/7b_dolly_sft.yaml on 1 node with 8 A100-40GB.

Note: I did have to override the load_path: ... line because it is just a placeholder and, unless you went through the pre-training steps to create a pre-trained checkpoint, you won't want to use load_path.
To override it, use the command: composer train.py yamls/mpt/finetune/7b_dolly_sft.yaml load_path=null
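Command-line arguments like load_path=null are key=value overrides merged on top of the YAML after it is loaded. The config uses OmegaConf-style interpolation (${global_seed}), and the training script is believed to merge CLI overrides with OmegaConf; the helper below is only a hypothetical stdlib stand-in to illustrate the idea. It decodes values as JSON (so null becomes None), whereas OmegaConf follows YAML rules, and the placeholder path is made up.

```python
import json


def apply_dotlist_overrides(cfg: dict, overrides: list[str]) -> dict:
    """Return a copy of cfg with `key.sub=value` overrides applied.

    Hypothetical helper. Values are decoded as JSON when possible
    (`null` -> None, `4` -> 4, `true` -> True); otherwise kept as strings.
    """
    merged = json.loads(json.dumps(cfg))  # deep copy for JSON-like configs
    for item in overrides:
        dotted_key, _, raw_value = item.partition("=")
        try:
            value = json.loads(raw_value)
        except json.JSONDecodeError:
            value = raw_value  # e.g. `torch` stays the string "torch"
        node = merged
        *parents, leaf = dotted_key.split(".")
        for part in parents:
            node = node.setdefault(part, {})  # walk/create nested sections
        node[leaf] = value
    return merged


yaml_cfg = {"load_path": "hypothetical/placeholder/path", "device_train_microbatch_size": 8}
cfg = apply_dotlist_overrides(yaml_cfg, ["load_path=null", "device_train_microbatch_size=4"])
print(cfg)  # {'load_path': None, 'device_train_microbatch_size': 4}
```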

To help me better understand the issue, a couple questions:

arpitg1991 commented 1 year ago

I am using the Docker image now, but I'm still getting an error:

root@042b5478a3f4:~/llm-foundry/scripts/train# python train.py yamls/mpt/finetune/7b_dolly_sft.yaml load_path=null
Initializing model...
cfg.n_params=6.66e+09
Building train loader...
Using pad_token, but it is not set yet.
Re-formatting dataset with "HuggingFaceH4/databricks_dolly_15k" preprocessing function.
Found cached dataset parquet (/root/.cache/huggingface/datasets/HuggingFaceH4___parquet/HuggingFaceH4--databricks_dolly_15k-6252f3495e7d2b9d/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/HuggingFaceH4___parquet/HuggingFaceH4--databricks_dolly_15k-6252f3495e7d2b9d/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-9cf4c020b7f570d5.arrow
Building eval loader...
Building trainer...
Traceback (most recent call last):
  File "/root/llm-foundry/scripts/train/train.py", line 254, in <module>
    main(cfg)
  File "/root/llm-foundry/scripts/train/train.py", line 197, in main
    trainer = Trainer(
  File "/usr/lib/python3/dist-packages/composer/trainer/trainer.py", line 954, in __init__
    prepare_fsdp_module(model, optimizers, self.fsdp_config, precision, device, auto_microbatching)
  File "/usr/lib/python3/dist-packages/composer/trainer/dist_strategy.py", line 355, in prepare_fsdp_module
    fsdp_obj = MosaicFullyShardedDataParallel(
  File "/usr/lib/python3/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1062, in __init__
    self._materialize_module(module, param_init_fn, device_from_device_id)
  File "/usr/lib/python3/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1368, in _materialize_module
    param_init_fn(module)
  File "/usr/lib/python3/dist-packages/composer/trainer/dist_strategy.py", line 283, in _param_init_fn
    meta_safe_apply(module,
  File "/usr/lib/python3/dist-packages/composer/trainer/meta_safe_apply.py", line 36, in meta_safe_apply
    meta_safe_apply(module, fn, ignored_modules, curr_module_name)
  File "/usr/lib/python3/dist-packages/composer/trainer/meta_safe_apply.py", line 36, in meta_safe_apply
    meta_safe_apply(module, fn, ignored_modules, curr_module_name)
  File "/usr/lib/python3/dist-packages/composer/trainer/meta_safe_apply.py", line 36, in meta_safe_apply
    meta_safe_apply(module, fn, ignored_modules, curr_module_name)
  [Previous line repeated 1 more time]
  File "/usr/lib/python3/dist-packages/composer/trainer/meta_safe_apply.py", line 60, in meta_safe_apply
    param_applied = fn(param)
  File "/usr/lib/python3/dist-packages/composer/trainer/dist_strategy.py", line 284, in <lambda>
    lambda t: torch.empty_like(t, device=f'cuda:{torch.cuda.current_device()}'),
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 770.00 MiB (GPU 0; 39.56 GiB total capacity; 38.31 GiB already allocated; 202.56 MiB free; 38.43 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
alextrott16 commented 1 year ago

I think I see the problem. This looks like the same error I was getting when using python instead of composer.

Can you run composer train.py yamls/mpt/finetune/7b_dolly_sft.yaml load_path=null instead?

abhi-mosaic commented 1 year ago

Adding on to @alextrott16 's comment:

When you use python train.py ..., you launch just one process, which runs training on one GPU. In that situation FSDP has nothing to shard across, so your 7B model is not sharded: the weights, optimizer state and activation memory must all fit on 1xA100, and it runs out of memory (OOM).

When you use the composer launcher, as in composer train.py ..., it detects the number of GPUs on your machine and launches 8 processes, one per GPU. The 7B model then gets sharded across the 8 GPUs, so the memory should fit per GPU and it should work!
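The difference between the two launch modes comes down to how many worker processes exist. The sketch below is a simplified, hypothetical illustration of the per-process environment a single-node launcher sets up under the torch.distributed env:// convention (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT, plus LOCAL_RANK to pick the GPU). It is not Composer's launcher code, and the address and port defaults are arbitrary choices.

```python
def build_rank_envs(n_gpus: int, master_addr: str = "127.0.0.1",
                    master_port: int = 29500) -> list[dict[str, str]]:
    """Environment variables for each worker on a single node (hypothetical helper).

    One process per GPU. torch.distributed's env:// init reads RANK,
    WORLD_SIZE, MASTER_ADDR and MASTER_PORT; LOCAL_RANK is the conventional
    variable a worker uses to select its GPU.
    """
    envs = []
    for rank in range(n_gpus):
        envs.append({
            "RANK": str(rank),
            "LOCAL_RANK": str(rank),  # single node: local rank == global rank
            "WORLD_SIZE": str(n_gpus),
            "MASTER_ADDR": master_addr,
            "MASTER_PORT": str(master_port),
        })
    return envs


# A bare `python train.py` behaves like world size 1: one shard holds everything.
print(len(build_rank_envs(1)))  # 1
# On this 8-GPU node a launcher would start 8 workers, each holding 1/8 of the shards.
print([env["RANK"] for env in build_rank_envs(8)])  # ['0', '1', '2', '3', '4', '5', '6', '7']
```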

arpitg1991 commented 1 year ago

I got a bit further, but there was an issue with CUDA and PyTorch compatibility. I have CUDA 11.8, so I installed PyTorch 2.0.1 and Triton 2.0.0, but now I'm seeing the error below. Should I try with CUDA 11.7?

  runtime_estimator: {}
load_path: null
dist_timeout: 600.0
n_gpus: 8
device_train_batch_size: 8
device_train_grad_accum: 1
n_params: 6658859008

Starting training...
******************************
Config:
enabled_algorithms/GradientClipping: true
node_name: unknown because NODENAME environment variable not set
num_gpus_per_node: 8
num_nodes: 1
rank_zero_seed: 17

******************************
/home/ec2-user/llm-foundry/llmfoundry/data/finetuning/collator.py:188: UserWarning: Truncating TARGET sequence of length=849 to length=812, so context+target fit max_seq_len=2048. If truncation is a problem, consider increasing max_seq_len.
  warnings.warn(
Traceback (most recent call last):
  File "<string>", line 21, in _fwd_kernel
KeyError: ('2-.-0-.-0-83ca8b715a9dc5f32dc1110973485f64-d6252949da17ceb5f3a278a70250af13-3b85c7bef5f0a641282f3b73af50f599-3d2aedeb40d6d81c66a42791e268f98b-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32, torch.float32, 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), ('vector', True, 128, True, True, True, 128, 128), (True, True, True, True, True, True, True, (False,), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/ec2-user/miniconda3/lib/python3.10/site-packages/triton/compiler.py", line 937, in build_triton_ir
    generator.visit(fn.parse())
  File "/home/ec2-user/miniconda3/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/home/ec2-user/miniconda3/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/ec2-user/miniconda3/lib/python3.10/site-packages/triton/compiler.py", line 183, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/home/ec2-user/miniconda3/lib/python3.10/ast.py", line 426, in generic_visit
    self.visit(item)
  File "/home/ec2-user/miniconda3/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/home/ec2-user/miniconda3/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/ec2-user/miniconda3/lib/python3.10/site-packages/triton/compiler.py", line 252, in visit_FunctionDef
    has_ret = self.visit_compound_statement(node.body)
  File "/home/ec2-user/miniconda3/lib/python3.10/site-packages/triton/compiler.py", line 177, in visit_compound_statement
    self.last_ret_type = self.visit(stmt)
  File "/home/ec2-user/miniconda3/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/home/ec2-user/miniconda3/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/ec2-user/miniconda3/lib/python3.10/site-packages/triton/compiler.py", line 678, in visit_For
    self.visit_compound_statement(node.body)
  File "/home/ec2-user/miniconda3/lib/python3.10/site-packages/triton/compiler.py", line 177, in visit_compound_statement
    self.last_ret_type = self.visit(stmt)
  File "/home/ec2-user/miniconda3/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/home/ec2-user/miniconda3/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/ec2-user/miniconda3/lib/python3.10/site-packages/triton/compiler.py", line 319, in visit_AugAssign
    self.visit(assign)
  File "/home/ec2-user/miniconda3/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/home/ec2-user/miniconda3/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/ec2-user/miniconda3/lib/python3.10/site-packages/triton/compiler.py", line 301, in visit_Assign
    values = self.visit(node.value)
  File "/home/ec2-user/miniconda3/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/home/ec2-user/miniconda3/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/ec2-user/miniconda3/lib/python3.10/site-packages/triton/compiler.py", line 339, in visit_BinOp
    rhs = self.visit(node.right)
  File "/home/ec2-user/miniconda3/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/home/ec2-user/miniconda3/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/ec2-user/miniconda3/lib/python3.10/site-packages/triton/compiler.py", line 797, in visit_Call
    return fn(*args, _builder=self.builder, **kws)
  File "/home/ec2-user/miniconda3/lib/python3.10/site-packages/triton/impl/base.py", line 22, in wrapper
    return fn(*args, **kwargs)
TypeError: dot() got an unexpected keyword argument 'trans_b'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ec2-user/llm-foundry/scripts/train/train.py", line 254, in <module>
    main(cfg)
    # off_h = tl.program_id(2)
    # off_hb = off_b * nheads + off_h
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # Initialize pointers to Q, K, V
    # Adding parenthesis around indexing might use int32 math instead of int64 math?
    # https://github.com/openai/triton/issues/741
    # I'm seeing a tiny bit of difference (5-7us)
    q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
    if BIAS_TYPE == 'vector':
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
    elif BIAS_TYPE == 'matrix':
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
    # initialize pointer to m and l
    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
    # tl.load(q_ptrs), we get the wrong output!
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                        other=0.0)
    # loop over k, v and update accumulator
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,
                            other=0.0)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn,
                            mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                            other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
                        ^
ERROR:composer.cli.launcher:Rank 2 crashed with exit code 1.
Waiting up to 30 seconds for all training processes to terminate. Press Ctrl-C to exit immediately.
Global rank 0 (PID 45705) exited with code 1
Global rank 2 (PID 45707) exited with code 1
----------Begin global rank 2 STDOUT----------
Initializing model...
cfg.n_params=6.66e+09
Building train loader...
Re-formatting dataset with "HuggingFaceH4/databricks_dolly_15k" preprocessing function.
Building eval loader...
Building trainer...
Logging config...
max_seq_len: 2048
global_seed: 17
run_name: llm
model:
  name: mpt_causal_lm
  init_device: meta
  d_model: 4096
  n_heads: 32
  n_layers: 32
  expansion_ratio: 4
  max_seq_len: ${max_seq_len}
  vocab_size: 50368
  attn_config:
    attn_impl: triton
tokenizer:
  name: EleutherAI/gpt-neox-20b
  kwargs:
    model_max_length: ${max_seq_len}
train_loader:
  name: finetuning
  dataset:
    hf_name: HuggingFaceH4/databricks_dolly_15k
    split: train
    max_seq_len: ${max_seq_len}
    allow_pad_trimming: false
    decoder_only_format: true
    shuffle: true
  drop_last: true
  num_workers: 8
  pin_memory: false
  prefetch_factor: 2
  persistent_workers: true
  timeout: 0
scheduler:
  name: linear_decay_with_warmup
  t_warmup: 0ba
  alpha_f: 0
optimizer:
  name: decoupled_adamw
  lr: 1.0e-05
  betas:
  - 0.9
  - 0.999
  eps: 1.0e-08
  weight_decay: 0
algorithms:
  gradient_clipping:
    clipping_type: norm
    clipping_threshold: 1.0
max_duration: 1ep
eval_interval: 1
global_train_batch_size: 64
seed: ${global_seed}
device_eval_batch_size: 8
device_train_microbatch_size: 8
precision: amp_bf16
fsdp_config:
  sharding_strategy: FULL_SHARD
  mixed_precision: PURE
  activation_checkpointing: true
  activation_checkpointing_reentrant: false

----------End global rank 7 STDERR----------
abhi-mosaic commented 1 year ago

Hi @arpitg1991, this repo is currently only tested with torch 1.13.1 (see the repo's dependency list), and it will be a few weeks before we fully transition to torch 2, as there are many moving pieces (Composer, Streaming, FSDP).

I would highly recommend using our public Docker image listed in the README (mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04). This is the image the MosaicML NLP team uses for all of our work, and it comes with flash-attn and triton preinstalled.
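A quick pre-flight check can catch a mismatch like the torch 2.0.1 / Triton 2.0.0 setup above before a long multi-GPU launch. The sketch below only compares installed package versions against pins you pass in; it is a hypothetical helper, and the only pin taken from this thread is torch 1.13.1. Local build tags such as +cu117 are ignored when comparing.

```python
from importlib import metadata


def version_mismatches(expected: dict[str, str]) -> list[str]:
    """Compare installed package versions against expected pins (hypothetical helper).

    Returns human-readable problems; an empty list means everything matched.
    """
    problems = []
    for package, wanted in expected.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            problems.append(f"{package}: not installed (expected {wanted})")
            continue
        # Drop a local build tag like '+cu117' before comparing.
        if installed.split("+")[0] != wanted:
            problems.append(f"{package}: found {installed}, expected {wanted}")
    return problems


# torch 1.13.1 is the version this thread says the repo was tested with.
for problem in version_mismatches({"torch": "1.13.1"}):
    print(problem)
```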

arpitg1991 commented 1 year ago

OK, I ran into disk space issues when I ran with Docker; let me try to fix that.

abhi-mosaic commented 1 year ago

Hi @arpitg1991, I am closing this issue for cleanup, but please add a comment or file a new issue if you need additional help. Cheers!