Training stops after 1 epoch ("Training epoch complete"): can't pretrain beyond 1 epoch? #554

Status: Open. Xuekai-Zhu opened this issue 3 months ago

Xuekai-Zhu commented 3 months ago

πŸ› Describe the bug

File: OLMo/olmo/

In the following training loop, does pre-training always stop after only 1 epoch?

```python
@property
def max_epochs(self) -> int:
    if isinstance(self.cfg.max_duration, str) and self.cfg.max_duration.endswith("ep"):
        return int(self.cfg.max_duration[:-2].strip())
    else:
        return 1
```

```python
with torch_profiler as p:
    for epoch in range(self.epoch or 0, self.max_epochs):
        for batch in self.train_loader:
            # Bookkeeping.
            # NOTE: To track the global batch size / number of tokens per batch we make the assumption that all
            # batches see the same number of tokens, which should be the case for language model pre-training
            # (at least when drop_last=True).
            # Alternatively we'd have to use a distributed all reduce over seq_len here, but I don't want that
            # overhead. So for now I'm putting these assertions here so if the assumption is violated it will
            # fail loudly.
            batch_size, seq_len = batch["input_ids"].shape
            assert seq_len == self.cfg.model.max_sequence_length
            assert batch_size == self.cfg.device_train_batch_size
            global_batch_size = batch_size * get_world_size()  # assumes batch size equal across ranks
            self.global_step += 1
            self.global_train_examples_seen_this_epoch += global_batch_size
            self.global_train_tokens_seen += global_batch_size * seq_len
            speed_monitor.batch_start(
                self.global_train_tokens_seen,
                batch_size * seq_len,  # num tokens in batch for this device
                # We start monitoring speed after the first batch since the first
                # batch might be an outlier due to compiling and other initialization overhead.
                record=not first_batch,
            )

            should_log_this_step = self.should_log_this_step()

            # Run train step on batch.
            metrics = self.train_step(batch, reduce_global_loss=should_log_this_step)

            # Maybe collect other metrics.
            if should_log_this_step:
                # Speed metrics.
                metrics.update(speed_monitor.check())
                # System metrics.
                metrics.update(self.system_metrics())
                # Learning rate metrics.
                metrics.update(lr_monitor.check())

            # Log metrics to console.
            if self.global_step % self.cfg.console_log_interval == 0:
                self.log_metrics_to_console(f"[step={self.global_step}/{self.max_steps}]", metrics)

            # Log metrics to W&B.
            if (
                wandb.run is not None
                and self.cfg.wandb is not None
                and self.global_step % self.cfg.wandb.log_interval == 0
            ):
                wandb.log(metrics, step=self.global_step)

            # Check if/when run should be canceled.
            if not cancel_initiated and self.global_step % self.cfg.canceled_check_interval == 0:
                cancel_initiated, extra_steps = self.check_if_cancelled()
                if cancel_initiated:
                    stop_at = (
                        self.global_step + extra_steps
                        if stop_at is None
                        else min(self.global_step + extra_steps, stop_at)
                    )

            # Maybe save sharded checkpoint.
            if save_checkpoints and (
                cancel_initiated
                or (
                    self.global_step % self.cfg.save_interval == 0
                    and self.cfg.save_num_checkpoints_to_keep != 0
                )
            ):
                log.info("Saving checkpoint...")
                checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
                log.info(f"Checkpoint saved to {checkpoint_path}")

                # Remove any ephemeral checkpoints.
                while self.ephemeral_checkpoints:
                    self.remove_ephemeral_checkpoint()

                # Reset speed monitor so that we don't count the time taken to save checkpoints.
                speed_monitor.reset()

                # If the run was just canceled this will be the final checkpoint.
                if cancel_initiated:
                    save_checkpoints = False
            elif (
                self.cfg.save_interval_ephemeral is not None
                and self.global_step % self.cfg.save_interval_ephemeral == 0
            ):
                log.info("Saving ephemeral checkpoint...")
                checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded_ephemeral)
                log.info(f"Checkpoint saved to {checkpoint_path}")

                # Reset speed monitor so that we don't count the time taken to save checkpoints.
                speed_monitor.reset()

            # Maybe save unsharded checkpoint.
            if (
                save_checkpoints
                and self.cfg.save_interval_unsharded is not None
                and self.global_step % self.cfg.save_interval_unsharded == 0
                and self.cfg.save_num_unsharded_checkpoints_to_keep != 0
            ):
                log.info("Saving unsharded checkpoint...")
                checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded)
                log.info(f"Unsharded checkpoint saved to {checkpoint_path}")

                # Reset speed monitor so that we don't count the time taken to save checkpoints.
                speed_monitor.reset()

            # Maybe run evaluations.
            if not cancel_initiated and self.global_step % self.cfg.eval_interval == 0:
                eval_metrics = self.eval()

                # Log metrics to W&B.
                if wandb.run is not None:
                    wandb.log(eval_metrics, step=self.global_step)

                # Reset speed monitor so that we don't count the time taken to run evaluations.
                speed_monitor.reset()

                # Reset model to 'train' mode.
                self.fsdp_model.train()

            # End of batch.
            first_batch = False
            if p is not None:
                p.step()

            if stop_at is not None and self.global_step >= stop_at:
                break

            # Python Profiler stuff
            # We do this now, at the bottom of this loop, so we capture the work of getting the next batch.
            if python_profiler is not None:
                if self.global_step == 5:
                    python_profiler.enable()
                elif self.global_step == 8:
                    python_profiler.disable()
                    python_profiler.print_stats(sort=SortKey.CUMULATIVE)
                    python_profiler = None
        else:
            log.info("Training epoch complete")
            self.epoch = epoch + 1
            self.global_train_examples_seen_this_epoch = 0
            if self.epoch < self.max_epochs:
                self.dataset.reshuffle()
            continue

        break
```

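The "Training epoch complete" branch is the `else` clause of the inner `for batch in self.train_loader` loop (Python's `for ... else`), so it runs only when the loader is exhausted without a `break`. A minimal sketch of that control flow, using a hypothetical `simulate_training` function with integer counters instead of real batches, and a token budget (`max_tokens`, also hypothetical) standing in for `stop_at`, shows why the run ends after one pass when `max_epochs` is 1:

```python
from typing import Tuple


def simulate_training(
    batches_per_epoch: int, tokens_per_batch: int, max_tokens: int, max_epochs: int
) -> Tuple[int, int]:
    """Mimic the epoch/batch loop structure; return (epochs_completed, tokens_seen)."""
    # Hypothetical stand-in for `stop_at`: the step at which the token budget is reached.
    stop_at = max_tokens // tokens_per_batch
    global_step = 0
    tokens_seen = 0
    epochs_completed = 0
    for epoch in range(max_epochs):
        for _ in range(batches_per_epoch):
            global_step += 1
            tokens_seen += tokens_per_batch
            if global_step >= stop_at:
                break  # budget reached mid-epoch
        else:
            # Inner loop ran out of batches without `break`: the epoch is complete.
            epochs_completed = epoch + 1
            continue  # go to the next epoch, if range(max_epochs) allows one
        break  # inner loop hit `break`, so leave the outer loop as well
    return epochs_completed, tokens_seen


# Scaled-down version of the report: 8 units of data, a 30-unit budget.
print(simulate_training(batches_per_epoch=8, tokens_per_batch=1, max_tokens=30, max_epochs=1))  # (1, 8)
print(simulate_training(batches_per_epoch=8, tokens_per_batch=1, max_tokens=30, max_epochs=4))  # (3, 30)
```

With `max_epochs=1` the scaled-down run stops after 8 of the 30 budgeted units, the same shape as the 8B-of-30B symptom in the thread; with 4 epochs allowed, the budget check ends the run partway through the fourth epoch.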

dumitrac commented 2 months ago

@Xuekai-Zhu, what is the value of "max_duration" in the config you're using? If you want training to run for more than 1 epoch, say 2 epochs, the config should have max_duration: 2ep.
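To see what the epoch count resolves to for different settings, here is a standalone copy of the `max_epochs` logic from the issue written as a plain function (the name `max_epochs_for` is hypothetical). The string `"30e9T"` is only an assumed example of a token-based setting; the point is that any value that is not a string ending in `ep` falls back to 1 epoch:

```python
from typing import Union


def max_epochs_for(max_duration: Union[int, str]) -> int:
    """Standalone copy of the max_epochs logic: '<N>ep' means N epochs, anything else means 1."""
    if isinstance(max_duration, str) and max_duration.endswith("ep"):
        # Drop the "ep" suffix and parse the remaining epoch count.
        return int(max_duration[:-2].strip())
    # Step- or token-based durations fall through to the single-epoch default.
    return 1


for value in ["2ep", "30e9T", 1000]:
    print(repr(value), "->", max_epochs_for(value))
```

So a token- or step-based max_duration leaves the outer `for epoch in range(...)` loop with a single iteration.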

Xuekai-Zhu commented 2 months ago

Yes, I found that to train for more than 1 epoch the config should have max_duration: 2ep. But when I want to use a maximum number of tokens to control the training process, I can't reach that token count, because training is limited to the default of 1 epoch.

Source tokens: 8B; max_duration: 30B -> training completes at 8B tokens (1 epoch).
❌ Training can't reach the max_duration set in the config.
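The numbers in the report make the gap concrete: 30B budgeted tokens over 8B source tokens need ⌈30 / 8⌉ = 4 passes over the data, yet the loop allows only one. A minimal sketch of one possible fix, assuming the number of tokens per epoch is known up front (for example, number of training instances × sequence length), derives the epoch count from the token budget instead of defaulting to 1. `epochs_needed` is a hypothetical helper, not an OLMo function:

```python
def epochs_needed(max_tokens: int, tokens_per_epoch: int) -> int:
    """Smallest number of passes over the data whose tokens cover max_tokens (at least 1)."""
    if tokens_per_epoch <= 0:
        raise ValueError("tokens_per_epoch must be positive")
    # Integer ceiling division: a partial final epoch still needs one more pass.
    return max(1, -(-max_tokens // tokens_per_epoch))


print(epochs_needed(max_tokens=30_000_000_000, tokens_per_epoch=8_000_000_000))  # 4
```

Using this value as the upper bound of the outer `range(...)` would let the loop keep reshuffling and iterating; assuming the step limit is derived from the same 30B-token budget, that limit would then end the run partway through the fourth epoch.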
dumitrac commented 2 months ago

@Xuekai-Zhu - agreed, this is a bug. Thank you for reporting it.