allenai / OLMo

Modeling, training, eval, and inference code for OLMo
https://allenai.org/olmo
Apache License 2.0

Training stops after 1 epoch ("Training epoch complete"): can't pretrain beyond 1 epoch? #554

Open Xuekai-Zhu opened 3 months ago

Xuekai-Zhu commented 3 months ago

πŸ› Describe the bug

File: OLMo/olmo/train.py. In the following training loop, does pre-training always stop after only 1 epoch?

@property
def max_epochs(self) -> int:
    if isinstance(self.cfg.max_duration, str) and self.cfg.max_duration.endswith("ep"):
        return int(self.cfg.max_duration[:-2].strip())
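    else:
        # Any other max_duration (e.g. a token budget or step count) falls back to 1 epoch.
        return 1

# Later, inside the training loop:
with torch_profiler as p: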
            for epoch in range(self.epoch or 0, self.max_epochs):
                for batch in self.train_loader:
                    # Bookkeeping.
                    # NOTE: To track the global batch size / number of tokens per batch we make the assumption that all
                    # batches see the same number of tokens, which should be the case for language model pre-training
                    # (at least when drop_last=True).
                    # Alternatively we'd have to use a distributed all reduce over seq_len here, but I don't want that
                    # overhead. So for now I'm putting these assertions here so if the assumption is violated it will
                    # fail loudly.
                    batch_size, seq_len = batch["input_ids"].shape
                    assert seq_len == self.cfg.model.max_sequence_length
                    assert batch_size == self.cfg.device_train_batch_size
                    global_batch_size = batch_size * get_world_size()  # assumes batch size equal across ranks
                    self.global_step += 1
                    self.global_train_examples_seen_this_epoch += global_batch_size
                    self.global_train_tokens_seen += global_batch_size * seq_len
                    speed_monitor.batch_start(
                        self.global_train_tokens_seen,
                        batch_size * seq_len,  # num tokens in batch for this device
                        # We start monitoring speed after the first batch since the first
                        # batch might be an outlier due to compiling and other initialization overhead.
                        record=not first_batch,
                    )

                    should_log_this_step = self.should_log_this_step()

                    # Run train step on batch.
                    metrics = self.train_step(batch, reduce_global_loss=should_log_this_step)

                    # Maybe collect other metrics.
                    if should_log_this_step:
                        # Speed metrics.
                        metrics.update(speed_monitor.check())
                        # System metrics.
                        metrics.update(self.system_metrics())
                        # Learning rate metrics.
                        metrics.update(lr_monitor.check())

                    # Log metrics to console.
                    if self.global_step % self.cfg.console_log_interval == 0:
                        self.log_metrics_to_console(f"[step={self.global_step}/{self.max_steps}]", metrics)

                    # Log metrics to W&B.
                    if (
                        wandb.run is not None
                        and self.cfg.wandb is not None
                        and self.global_step % self.cfg.wandb.log_interval == 0
                    ):
                        wandb.log(metrics, step=self.global_step)

                    # Check if/when run should be canceled.
                    if not cancel_initiated and self.global_step % self.cfg.canceled_check_interval == 0:
                        cancel_initiated, extra_steps = self.check_if_cancelled()
                        if cancel_initiated:
                            stop_at = (
                                self.global_step + extra_steps
                                if stop_at is None
                                else min(self.global_step + extra_steps, stop_at)
                            )

                    # Maybe save sharded checkpoint.
                    if save_checkpoints and (
                        cancel_initiated
                        or (
                            self.global_step % self.cfg.save_interval == 0
                            and self.cfg.save_num_checkpoints_to_keep != 0
                        )
                    ):
                        log.info("Saving checkpoint...")
                        checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
                        log.info(f"Checkpoint saved to {checkpoint_path}")

                        # Remove any ephemeral checkpoints.
                        while self.ephemeral_checkpoints:
                            self.remove_ephemeral_checkpoint()

                        # Reset speed monitor so that we don't count the time taken to save checkpoints.
                        speed_monitor.reset()

                        # If the run was just canceled this will be the final checkpoint.
                        if cancel_initiated:
                            save_checkpoints = False
                    elif (
                        self.cfg.save_interval_ephemeral is not None
                        and self.global_step % self.cfg.save_interval_ephemeral == 0
                    ):
                        log.info("Saving ephemeral checkpoint...")
                        checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded_ephemeral)
                        log.info(f"Checkpoint saved to {checkpoint_path}")

                        # Reset speed monitor so that we don't count the time taken to save checkpoints.
                        speed_monitor.reset()

                    # Maybe save unsharded checkpoint.
                    if (
                        save_checkpoints
                        and self.cfg.save_interval_unsharded is not None
                        and self.global_step % self.cfg.save_interval_unsharded == 0
                        and self.cfg.save_num_unsharded_checkpoints_to_keep != 0
                    ):
                        log.info("Saving unsharded checkpoint...")
                        checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded)
                        log.info(f"Unsharded checkpoint saved to {checkpoint_path}")

                        # Reset speed monitor so that we don't count the time taken to save checkpoints.
                        speed_monitor.reset()

                    # Maybe run evaluations.
                    if not cancel_initiated and self.global_step % self.cfg.eval_interval == 0:
                        eval_metrics = self.eval()

                        # Log metrics to W&B.
                        if wandb.run is not None:
                            wandb.log(eval_metrics, step=self.global_step)

                        # Reset speed monitor so that we don't count the time taken to run evaluations.
                        speed_monitor.reset()

                        # Reset model to 'train' mode.
                        self.fsdp_model.train()

                    # End of batch.
                    first_batch = False
                    if p is not None:
                        p.step()

                    if stop_at is not None and self.global_step >= stop_at:
                        break

                    # Python Profiler stuff
                    # We do this now, at the bottom of this loop, so we capture the work of getting the next batch.
                    if python_profiler is not None:
                        if self.global_step == 5:
                            python_profiler.enable()
                        elif self.global_step == 8:
                            python_profiler.disable()
                            python_profiler.print_stats(sort=SortKey.CUMULATIVE)
                            python_profiler = None
                else:
                    # for/else: this branch runs only if the batch loop finished without `break`.
                    log.info("Training epoch complete")
                    self.epoch = epoch + 1
                    self.global_train_examples_seen_this_epoch = 0
                    if self.epoch < self.max_epochs:
                        self.dataset.reshuffle()
                    continue

                # Reached only when the batch loop hit `break` (stop_at reached), which ends training.
                break
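To make the failure concrete, here is a minimal, self-contained simulation of the loop above. It is a sketch, not OLMo code: `run`, `batches_per_epoch` and `max_steps` are hypothetical names, and it assumes the run is supposed to end when the step limit derived from `max_duration` (used here as `stop_at`) is reached. With a token budget, `max_epochs` is 1, so the `for`/`else` path ends the outer loop after a single pass even though the step limit was never hit.

```python
from typing import Union


def max_epochs(max_duration: Union[int, str]) -> int:
    # Mirrors the property quoted above: only an "...ep" duration sets the epoch count;
    # anything else (a step count or token budget) falls back to 1.
    if isinstance(max_duration, str) and max_duration.endswith("ep"):
        return int(max_duration[:-2].strip())
    return 1


def run(max_duration: Union[int, str], batches_per_epoch: int, max_steps: int) -> int:
    """Return the number of steps a simplified version of the epoch loop actually takes."""
    global_step = 0
    stop_at = max_steps  # assumption: the step limit implied by max_duration
    completed_epochs = 0
    for epoch in range(completed_epochs, max_epochs(max_duration)):
        for _ in range(batches_per_epoch):
            global_step += 1
            if global_step >= stop_at:
                break
        else:
            # Batch loop exhausted without reaching stop_at: the epoch is complete.
            completed_epochs = epoch + 1
            continue
        break  # only reached when the inner loop hit `break`
    return global_step
```

With 100 batches per epoch and a budget worth 375 steps (the same 3.75-epoch ratio as an 8B-token dataset with a 30B-token budget), `run("30e9T", 100, 375)` stops at step 100, while `run("4ep", 100, 400)` reaches its limit of 400.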

Versions

Python 3.10.13 WARNING: Could not find a Python project for directory /scratch2/nlp/zhuxuekai/scaling_law4AI_data/OLMo (tried all parent directories) -e git+ssh://git@github.com/Xuekai-Zhu/scaling_law4AI_data.git@a15301e68a4dd616e3971c54370cb4a957e4d14c#egg=ai2_olmo aiohttp==3.9.3 aiosignal==1.3.1 aniso8601==9.0.1 annotated-types==0.6.0 antlr4-python3-runtime==4.9.3 anykeystore==0.2 appdirs==1.4.4 async-timeout==4.0.3 asyncio==3.4.3 attrs==23.2.0 backports.tarfile==1.1.0 beaker-gantry==0.22.2 beaker-py==1.26.4 black==23.12.1 blinker==1.7.0 boltons==24.0.0 boto3==1.34.86 botocore==1.34.86 build==1.2.1 cached_path==1.6.2 cachetools==5.3.3 certifi==2024.2.2 cffi==1.16.0 charset-normalizer==3.3.2 click==8.1.7 click-help-colors==0.9.4 cmake==3.28.3 contourpy==1.2.0 cryptacular==1.6.2 cryptography==42.0.5 cycler==0.12.1 datasets==2.18.0 deepspeed==0.14.0 deepspeed-kernels==0.0.1.dev1698255861 deepspeed-mii==0.2.3 defusedxml==0.7.1 dill==0.3.8 docker==6.1.3 docker-pycreds==0.4.0 docutils==0.21.1 exceptiongroup==1.2.0 face==20.1.1 filelock==3.9.0 Flask==3.0.2 Flask-RESTful==0.3.10 fonttools==4.50.0 frozenlist==1.4.1 fsspec==2024.2.0 ftfy==6.2.0 gitdb==4.0.11 GitPython==3.1.42 glom==23.5.0 google-api-core==2.18.0 google-auth==2.29.0 google-cloud-core==2.4.1 google-cloud-storage==2.16.0 google-crc32c==1.5.0 google-resumable-media==2.7.0 googleapis-common-protos==1.63.0 greenlet==3.0.3 grpcio==1.62.1 grpcio-tools==1.62.1 hjson==3.1.0 huggingface-hub==0.21.4 hupper==1.12.1 idna==3.6 importlib_metadata==7.1.0 iniconfig==2.0.0 isort==5.12.0 itsdangerous==2.1.2 jaraco.classes==3.4.0 jaraco.context==5.3.0 jaraco.functools==4.0.0 jeepney==0.8.0 Jinja2==3.1.2 jmespath==1.0.1 joblib==1.4.0 keyring==25.1.0 kiwisolver==1.4.5 lightning-utilities==0.11.2 markdown-it-py==3.0.0 MarkupSafe==2.1.3 matplotlib==3.8.3 mdurl==0.1.2 Megatron==0.5.1 megatron_core==0.5.0 more-itertools==10.2.0 mpmath==1.3.0 msgspec==0.18.6 multidict==6.0.5 multiprocess==0.70.16 mypy==1.3.0 mypy-extensions==1.0.0 
necessary==0.4.3 networkx==3.2.1 nh3==0.2.17 ninja==1.11.1.1 numpy==1.26.4 nvidia-cublas-cu11==11.11.3.6 nvidia-cuda-cupti-cu11==11.8.87 nvidia-cuda-nvrtc-cu11==11.8.89 nvidia-cuda-runtime-cu11==11.8.89 nvidia-cudnn-cu11==8.7.0.84 nvidia-cufft-cu11==10.9.0.58 nvidia-curand-cu11==10.3.0.86 nvidia-cusolver-cu11==11.4.1.48 nvidia-cusparse-cu11==11.7.5.86 nvidia-nccl-cu11==2.19.3 nvidia-nvtx-cu11==11.8.86 oauthlib==3.2.2 omegaconf==2.3.0 packaging==24.0 pandas==2.2.1 PasteDeploy==3.1.0 pathspec==0.12.1 pbkdf2==1.3 petname==2.6 pillow==10.2.0 pkginfo==1.10.0 plaster==1.1.2 plaster-pastedeploy==1.0.1 platformdirs==4.2.0 pluggy==1.4.0 proto-plus==1.23.0 protobuf==4.25.3 psutil==5.9.8 py-cpuinfo==9.0.0 pyarrow==15.0.2 pyarrow-hotfix==0.6 pyasn1==0.6.0 pyasn1_modules==0.4.0 pycparser==2.22 pydantic==2.6.4 pydantic_core==2.16.3 Pygments==2.17.2 pynvml==11.5.0 pyparsing==3.1.2 pyproject_hooks==1.0.0 pyramid==2.0.2 pyramid-mailer==0.15.1 pytest==8.1.1 pytest-sphinx==0.6.3 python-dateutil==2.9.0.post0 python3-openid==3.2.0 pytz==2024.1 PyYAML==6.0.1 pyzmq==25.1.2 readme_renderer==43.0 regex==2023.12.25 repoze.sendmail==4.4.1 requests==2.31.0 requests-oauthlib==2.0.0 requests-toolbelt==1.0.0 requirements-parser==0.9.0 rfc3986==2.0.0 rich==13.7.1 rsa==4.9 ruff==0.3.7 s3transfer==0.10.1 safetensors==0.4.2 scikit-learn==1.4.2 scipy==1.13.0 seaborn==0.13.2 SecretStorage==3.3.3 sentry-sdk==1.43.0 setproctitle==1.3.3 six==1.16.0 smart-open==7.0.4 smashed==0.21.5 smmap==5.0.1 SQLAlchemy==2.0.29 sympy==1.12 threadpoolctl==3.4.0 tokenizers==0.15.2 tomli==2.0.1 torch==2.2.1+cu118 torchmetrics==1.3.2 tqdm==4.66.2 transaction==4.0 transformers==4.38.2 translationstring==1.4 triton==2.2.0 trouting==0.3.3 twine==5.0.0 types-setuptools==69.5.0.20240415 typing_extensions==4.8.0 tzdata==2024.1 ujson==5.9.0 urllib3==2.2.1 velruse==1.1.1 venusian==3.1.0 wandb==0.16.4 wcwidth==0.2.13 WebOb==1.8.7 websocket-client==1.7.0 Werkzeug==3.0.1 wrapt==1.16.0 WTForms==3.1.2 wtforms-recaptcha==0.3.2 
xxhash==3.4.1 yarl==1.9.4 zipp==3.18.1 zmq==0.0.0 zope.deprecation==5.0 zope.interface==6.3 zope.sqlalchemy==3.1

dumitrac commented 2 months ago

@Xuekai-Zhu, what is the value of "max_duration" in the config that you're using? If you want it to be more than 1 epoch, say 2 epochs, the config should have max_duration: 2ep.
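For reference, `max_duration` can express the training length in different units. The sketch below classifies a value by unit; the `ep` suffix comes from this thread, while the `T` token suffix (as in `30e9T`) and bare step counts are assumptions about OLMo's config format that should be checked against `olmo/train.py`, and `parse_max_duration` is a hypothetical helper. In the code quoted above, only the epoch form influences `max_epochs`.

```python
from typing import Tuple, Union


def parse_max_duration(max_duration: Union[int, str]) -> Tuple[str, int]:
    """Hypothetical helper: classify max_duration as ("epochs" | "tokens" | "steps", amount).

    Assumed formats: "2ep" -> epochs, "30e9T" -> tokens, 1000 or "1000" -> steps.
    """
    if isinstance(max_duration, int):
        return ("steps", max_duration)
    text = max_duration.strip()
    if text.endswith("ep"):
        return ("epochs", int(text[:-2].strip()))
    if text.endswith("T"):
        # Convert via float() first so scientific notation such as "30e9" parses.
        return ("tokens", int(float(text[:-1].strip())))
    return ("steps", int(float(text)))
```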

Xuekai-Zhu commented 2 months ago

Yes, I found that if I want more than 1 epoch, the config should have max_duration: 2ep. But when I want to use a maximum number of tokens to control the training process, I can't reach that token budget, because training is limited by the default of 1 epoch.

Source tokens: 8B, max_duration: 30B -> training completes at 8B tokens (1 epoch).
❌ Can't reach the max_duration set in the config.
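In numbers: 30B / 8B = 3.75 passes over the data, so the epoch loop would have to run at least 4 times, with the token limit (assuming it still stops the run at the budget) ending training partway through the fourth epoch. Below is a minimal sketch of one possible fix, not necessarily how the maintainers addressed it: `epochs_needed`, `fixed_max_epochs` and `tokens_per_epoch` are hypothetical names, and the `T` token suffix is an assumption.

```python
from typing import Union


def epochs_needed(max_tokens: int, tokens_per_epoch: int) -> int:
    """Smallest number of full passes whose tokens cover max_tokens (at least 1)."""
    if tokens_per_epoch <= 0:
        raise ValueError("tokens_per_epoch must be positive")
    # Integer ceiling division avoids float rounding for large token counts.
    return max(1, (max_tokens + tokens_per_epoch - 1) // tokens_per_epoch)


def fixed_max_epochs(max_duration: Union[int, str], tokens_per_epoch: int) -> int:
    """Hypothetical patched max_epochs: also derive an epoch count from a token budget."""
    if isinstance(max_duration, str) and max_duration.endswith("ep"):
        return int(max_duration[:-2].strip())
    if isinstance(max_duration, str) and max_duration.endswith("T"):
        # float() first so scientific notation such as "30e9" parses.
        max_tokens = int(float(max_duration[:-1].strip()))
        return epochs_needed(max_tokens, tokens_per_epoch)
    # Step-based durations are left at 1 here; they would need a similar conversion.
    return 1
```

For the case above, `epochs_needed(30_000_000_000, 8_000_000_000)` gives 4. Rounding up rather than down matters: with 3 epochs the run would still stop at 24B tokens, short of the 30B budget.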

dumitrac commented 2 months ago

@Xuekai-Zhu - agreed, this is a bug. Thank you for reporting it.