mosaicml / llm-foundry

LLM training code for Databricks foundation models
https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm
Apache License 2.0

Problem with saving FSDP model #295

Closed · MFajcik closed this 1 year ago

MFajcik commented 1 year ago

Environment

Collecting system information...
---------------------------------
System Environment Report        
Created: 2023-06-07 09:44:28 CEST
---------------------------------

PyTorch information
-------------------
PyTorch version: 1.13.1+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: CentOS Linux release 7.9.2009 (Core) (x86_64)
GCC version: (GCC) 4.8.5 20150623 (Red Hat 4.8.5-44)
Clang version: Could not collect
CMake version: version 3.26.3
Libc version: glibc-2.17

Python version: 3.9.5 (default, Jun  4 2021, 12:28:51)  [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-3.10.0-1160.90.1.el7.x86_64-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100-SXM4-40GB
Nvidia driver version: 525.105.17
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.24.3
[pip3] pytorch-ranger==0.1.1
[pip3] torch==1.13.1
[pip3] torch-optimizer==0.3.0
[pip3] torchmetrics==0.11.3
[pip3] torchtext==0.14.1
[pip3] torchvision==0.14.1
[conda] numpy                     1.24.3                   pypi_0    pypi
[conda] pytorch-ranger            0.1.1                    pypi_0    pypi
[conda] torch                     1.13.1                   pypi_0    pypi
[conda] torch-optimizer           0.3.0                    pypi_0    pypi
[conda] torchmetrics              0.11.3                   pypi_0    pypi
[conda] torchtext                 0.14.1                   pypi_0    pypi
[conda] torchvision               0.14.1                   pypi_0    pypi

Composer information
--------------------
Composer version: 0.14.1
Composer commit hash: None
Host processor model name: AMD EPYC 7763 64-Core Processor
Host processor core count: 128
Number of nodes: 1
Accelerator model name: NVIDIA A100-SXM4-40GB
Accelerators per node: 8
CUDA Device Count: 8

To reproduce

Steps to reproduce the behavior:

  1. Use the following config:

```yaml
max_seq_len: 2048
global_seed: 17

# Run Name
run_name: TEST_LARGE_RUN  # If left blank, will be read from env var $COMPOSER_RUN_NAME

# Model
model:
  name: hf_causal_lm
  pretrained_model_name_or_path: mosaicml/mpt-7b
  init_device: cpu
  pretrained: true
  trust_remote_code: true
  config_overrides:
    # WARNING: if setting `pretrained: true`, `max_position_embeddings` must match the
    # `max_position_embeddings` used during pre-training
    n_positions: ${max_seq_len}
    attn_config:
      attn_impl: torch
      alibi: true

dataset: &hf_dataset
  hf_name: .data/testual_dataset
  decoder_only_format: true
  max_seq_len: ${max_seq_len}

# Tokenizer
tokenizer:
  name: EleutherAI/gpt-neox-20b
  kwargs:
    model_max_length: ${max_seq_len}

# Dataloaders
train_loader:
  name: finetuning
  dataset:
    <<: *hf_dataset
    split: train
    shuffle: true
    shuffle_seed: ${global_seed}
  drop_last: true
  num_workers: 2

eval_loader:
  name: finetuning
  dataset:
    <<: *hf_dataset
    split: test
    shuffle: false
    shuffle_seed: ${global_seed}
  drop_last: false
  num_workers: 2

# Optimization
scheduler:
  name: linear_decay_with_warmup
  t_warmup: 50ba  # 100 batches
  alpha_f: 0.1

optimizer:
  name: decoupled_lionw
  lr: 1.2e-4
  betas:

algorithms:
  gradient_clipping:
    clipping_type: norm
    clipping_threshold: 1.0

max_duration: 2000ba  # ~ 134B tokens
eval_interval: 200ba
eval_first: false
eval_subset_num_batches: -1
global_train_batch_size: 64

# System
seed: ${global_seed}
device_eval_batch_size: 1
device_train_microbatch_size: 1
# device_train_microbatch_size: auto

precision: amp_bf16

# FSDP
fsdp_config:
  sharding_strategy: FULL_SHARD
  mixed_precision: PURE
  activation_checkpointing: true
  activation_checkpointing_reentrant: false
  activation_cpu_offload: false
  limit_all_gathers: true
  verbose: false

# Logging
progress_bar: true
log_to_console: true
console_log_interval: 1ba

callbacks:
  speed_monitor:
    window_size: 10
  lr_monitor: {}
  memory_monitor: {}
  runtime_estimator: {}

loggers:
  wandb: {}

# Checkpoint to local filesystem or remote object store
save_interval: 1ba
save_num_checkpoints_to_keep: 1  # Important, this cleans up checkpoints saved to DISK
save_folder: ./mpt_saved
```
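Not part of the original report: a minimal sketch, assuming `omegaconf` is installed and the YAML above is saved as `finetune_mpt7b.yaml` (a hypothetical filename), of how to check that the config parses and that the `${max_seq_len}` / `${global_seed}` interpolations resolve; llm-foundry's `scripts/train/train.py` reads its config through OmegaConf in roughly the same way.

```python
# Not from the original issue: quick sanity check of the YAML above.
# Assumes omegaconf is installed and the config is saved as
# finetune_mpt7b.yaml (hypothetical filename).
from omegaconf import OmegaConf

cfg = OmegaConf.load("finetune_mpt7b.yaml")

# ${...} references resolve against the top-level keys, so these
# should print 2048 and 17 respectively.
print(cfg.model.config_overrides.n_positions)
print(cfg.train_loader.dataset.shuffle_seed)

# Dump the fully resolved config to verify the &hf_dataset anchor merged.
print(OmegaConf.to_yaml(cfg, resolve=True))
```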

  3. Create the dataset. I've just made a JSONL file with repeated contents:
```json
{"prompt": "Hello", "response": "world"}
{"prompt": "Hello", "response": "world"}

and simple dataset script

# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
import json
import os

import datasets

_CITATION = ""

_DESCRIPTION = """\
Test task
"""

_HOMEPAGE = ""

_LICENSE = ""

class TestualDataset(datasets.GeneratorBasedBuilder):
    VERSION = datasets.Version("1.1.0")

    BUILDER_CONFIGS = [
        datasets.BuilderConfig(name="text", version=VERSION, description=""),
    ]

    def _info(self):
        features = datasets.Features(
            {
                "prompt": datasets.Value("string"),
                "response": datasets.Value("string"),
            }
        )

        return datasets.DatasetInfo(
            # This is the description that will appear on the datasets page.
            description=_DESCRIPTION,
            # This defines the different columns of the dataset and their types
            features=features,  # Here we define them above because they are different between the two configurations
            # If there's a common (input, target) tuple from the features, uncomment supervised_keys line below and
            # specify them. They'll be used if as_supervised=True in builder.as_dataset.
            # supervised_keys=("sentence", "label"),
            # Homepage of the dataset for documentation
            homepage=_HOMEPAGE,
            # License for the dataset if available
            license=_LICENSE,
            # Citation for the dataset
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={
                    "filepath": os.path.join(".data/testual_dataset.jsonl"),
                    "split": "train",
                },
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={
                    "filepath": os.path.join(".data/testual_dataset.jsonl"),
                    "split": "test"
                },
            ),
        ]

    # method parameters are unpacked from `gen_kwargs` as given in `_split_generators`
    def _generate_examples(self, filepath, split):
        with open(filepath, encoding="utf-8") as f:
            for key, row in enumerate(f):
                data = json.loads(row)
                yield key, data
```
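Not in the original report: with the builder script saved so that `hf_name: .data/testual_dataset` points at it (the exact layout, e.g. `.data/testual_dataset/testual_dataset.py` next to the JSONL, is an assumption), the dataset can be exercised directly with `datasets.load_dataset`, which is roughly what the `finetuning` dataloader does under the hood.

```python
# Not from the original issue: a quick check that the builder script above
# yields prompt/response pairs before wiring it into the finetuning loader.
# Assumes the script is reachable at .data/testual_dataset (layout is an
# assumption) and that the `datasets` library is installed.
from datasets import load_dataset

train_ds = load_dataset(".data/testual_dataset", split="train")
print(len(train_ds))   # number of JSONL lines
print(train_ds[0])     # {'prompt': 'Hello', 'response': 'world'}
```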

Unfortunately, I get a strange error when the checkpoint is saved and the FSDP parameters are unsharded (end of STDOUT + STDERR from rank 5):

```
... training stuff, all was working OK until saving ...
Expects full precision but got torch.bfloat16

----------End global rank 5 STDOUT----------
----------Begin global rank 5 STDERR----------

  File "/projectdir/MPT_LLMFoundry/scripts/train/train.py", line 255, in <module>
    main(cfg)
  File "/projectdir/MPT_LLMFoundry/scripts/train/train.py", line 244, in main
    trainer.fit()
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/composer/trainer/trainer.py", line 1766, in fit
    self._train_loop()
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/composer/trainer/trainer.py", line 1996, in _train_loop
    self.engine.run_event(Event.BATCH_CHECKPOINT)
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/composer/core/engine.py", line 293, in run_event
    self._run_nonlogger_callbacks(event)
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/composer/core/engine.py", line 475, in _run_nonlogger_callbacks
    self._run_callbacks(event, callbacks)
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/composer/core/engine.py", line 467, in _run_callbacks
    cb.run_event(event, self.state, self.logger)
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/composer/core/callback.py", line 96, in run_event
    return event_cb(state, logger)
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/composer/callbacks/checkpoint_saver.py", line 346, in batch_checkpoint
    self._save_checkpoint(
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/composer/callbacks/checkpoint_saver.py", line 384, in _save_checkpoint
    saved_path = checkpoint.save_checkpoint(
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/composer/utils/checkpoint.py", line 518, in save_checkpoint
    'state': state.state_dict(),
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/composer/core/state.py", line 789, in state_dict
    model_state = attribute_value.state_dict()
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1448, in state_dict
    module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 2402, in state_dict
    with summon_ctx:
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/contextlib.py", line 117, in __enter__
    return next(self.gen)
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 2981, in _summon_full_params
    free_unsharded_flat_params = [handle.needs_unshard() for handle in self._handles]
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 2981, in <listcomp>
    free_unsharded_flat_params = [handle.needs_unshard() for handle in self._handles]
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/torch/distributed/fsdp/flat_param.py", line 681, in needs_unshard
    unsharded_flat_param = self._get_padded_unsharded_flat_param()
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/torch/distributed/fsdp/flat_param.py", line 714, in _get_padded_unsharded_flat_param
    p_assert(
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/torch/distributed/fsdp/_utils.py", line 147, in p_assert
    traceback.print_stack()
Traceback (most recent call last):
  File "/projectdir/MPT_LLMFoundry/scripts/train/train.py", line 255, in <module>
    main(cfg)
  File "/projectdir/MPT_LLMFoundry/scripts/train/train.py", line 244, in main
    trainer.fit()
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/composer/trainer/trainer.py", line 1766, in fit
    self._train_loop()
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/composer/trainer/trainer.py", line 1996, in _train_loop
    self.engine.run_event(Event.BATCH_CHECKPOINT)
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/composer/core/engine.py", line 293, in run_event
    self._run_nonlogger_callbacks(event)
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/composer/core/engine.py", line 475, in _run_nonlogger_callbacks
    self._run_callbacks(event, callbacks)
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/composer/core/engine.py", line 467, in _run_callbacks
    cb.run_event(event, self.state, self.logger)
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/composer/core/callback.py", line 96, in run_event
    return event_cb(state, logger)
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/composer/callbacks/checkpoint_saver.py", line 346, in batch_checkpoint
    self._save_checkpoint(
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/composer/callbacks/checkpoint_saver.py", line 384, in _save_checkpoint
    saved_path = checkpoint.save_checkpoint(
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/composer/utils/checkpoint.py", line 518, in save_checkpoint
    'state': state.state_dict(),
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/composer/core/state.py", line 789, in state_dict
    model_state = attribute_value.state_dict()
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1448, in state_dict
    module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 2402, in state_dict
    with summon_ctx:
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/contextlib.py", line 117, in __enter__
    return next(self.gen)
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 2981, in _summon_full_params
    free_unsharded_flat_params = [handle.needs_unshard() for handle in self._handles]
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 2981, in <listcomp>
    free_unsharded_flat_params = [handle.needs_unshard() for handle in self._handles]
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/torch/distributed/fsdp/flat_param.py", line 681, in needs_unshard
    unsharded_flat_param = self._get_padded_unsharded_flat_param()
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/torch/distributed/fsdp/flat_param.py", line 714, in _get_padded_unsharded_flat_param
    p_assert(
  File "/envdir/environment/miniconda/envs/mpt/lib/python3.9/site-packages/torch/distributed/fsdp/_utils.py", line 149, in p_assert
    raise AssertionError
AssertionError

----------End global rank 5 STDERR----------
```
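For context on the assertion (my reading of the trace, not an official explanation): `Expects full precision but got torch.bfloat16` is the `p_assert` message from `_get_padded_unsharded_flat_param`, i.e. when FSDP unshards the flat parameters for `state_dict()` it expects the full-precision (fp32) buffer but finds a bf16 one instead. Below is a minimal sketch, with a toy model and names of my own (not Composer's actual checkpointing code), of how a full, full-precision state dict is normally gathered from an FSDP module in PyTorch 1.13; the assertion in the traceback fires inside exactly this kind of `state_dict()` call.

```python
# Not from the original issue: minimal sketch of gathering a full-precision
# state dict from an FSDP-wrapped module (toy model, not MPT-7B). Assumes
# PyTorch 1.13, at least one CUDA device, and a launch such as
#   torchrun --nproc_per_node=1 fsdp_state_dict_sketch.py   (hypothetical file)
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = FSDP(torch.nn.Linear(16, 16).cuda())

# Ask FSDP for an unsharded FULL_STATE_DICT, offloaded to CPU on rank 0 only.
# Calling state_dict() under this context all-gathers the flat parameters;
# the p_assert seen in the traceback above fires during that unshard step.
full_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_cfg):
    state = model.state_dict()

if dist.get_rank() == 0:
    print({name: tensor.dtype for name, tensor in state.items()})  # expect fp32

dist.destroy_process_group()
```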

The llm-foundry version I am using is at commit 8ba68fb3442bbd373a75fbab29e5dfc26862195c (HEAD -> main, origin/main, origin/HEAD), dated Tue May 9 20:04:38 2023 -0700.

Expected behavior

FSDP checkpoint saving shouldn't crash the training.

Obviously it is some dtype error, but I spent several hours debugging it today and I am not sure what is cast wrong where... Any help is appreciated.

Cheers, Martin

MFajcik commented 1 year ago

Sorry, my bad. I think this was caused by a code modification of mine that I thought I had removed. (But some sshfs delay happened, and my test actually ran with that code still in place! :( ) It works now.