Lightning-AI / litdata

Streamline data pipelines for AI. Process datasets across 1000s of machines, and optimize data for blazing fast model training.
Apache License 2.0

PyTorch Lightning Fabric + litdata + DDP hangs when finishing epoch #129

Open miguelalba96 opened 1 month ago

miguelalba96 commented 1 month ago

🐛 Bug

I am training CLIP using PyTorch Lightning Fabric + litdata in a distributed setup (4 nodes with 4 GPUs each). I noticed that when the first epoch finishes, the training dataloaders hang on some nodes.

The image below shows fabric.print() logging on the 4 nodes just before the epoch ends (I print every 25 steps). Only one rank finishes successfully and the rest hang; otherwise the message ++++ Epoch: 0 completed ++++ would appear 4 times (once per node). I have shared the relevant parts of the code below; any help would be appreciated.

image

Additional context + Parts of Training code

StreamingDataset setup

I am using the following script to load the litdata:

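import io
import os
from ast import literal_eval
from functools import partial
from typing import Any, List

import numpy as np
import torch
from litdata import StreamingDataLoader, StreamingDataset
from PIL import Image
from transformers import CLIPProcessor

# Project-specific names (not shown in this issue): configs, CLASSES

class ImageCaptionDataset(StreamingDataset):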
    """
    Main dataset to retrieve samples from the optimized image-caption dataset
    """

    @staticmethod
    def get_zero_shot_one_hot(zero_shot_attributes: List[int]):
        one_hot_encoded = torch.zeros(len(CLASSES), dtype=torch.float)
        one_hot_encoded[zero_shot_attributes] = 1.0
        return one_hot_encoded

    def __getitem__(self, idx: int) -> Any:
        _, image_bytes, text_ids, mask, zero_shot_attr = super().__getitem__(idx)
        image = Image.open(io.BytesIO(image_bytes))
        input_ids = torch.tensor(np.frombuffer(text_ids, dtype=np.int64))
        attention_mask = torch.tensor(np.frombuffer(mask, dtype=np.int64))
        zero_shot_attr = self.get_zero_shot_one_hot(literal_eval(zero_shot_attr))
        return image, input_ids, attention_mask, zero_shot_attr

def collate_fn(batch, processor):
    """Arrange the batch into a dictionary of tensors for HF
    """
    images = processor(images=[ex[0] for ex in batch], return_tensors="pt")
    return {
        "pixel_values": images["pixel_values"],
        "input_ids": torch.stack([ex[1] for ex in batch]),
        "attention_mask": torch.stack([ex[2] for ex in batch]),
        "labels": torch.stack([ex[3] for ex in batch])
    }

def get_dataloader(split: str, config: configs.ExperimentConfig):
    dataset = ImageCaptionDataset(
        input_dir=os.path.join(config.local_data_path, split),
        shuffle=split == "train",
    )
    # Image transformation function
    processor = CLIPProcessor.from_pretrained(
        config.training_config.pre_trained_backbone
    )
    return StreamingDataLoader(
        dataset,
        batch_size=config.dataset_config.batch_size,
        shuffle=split == "train",
        num_workers=config.dataset_config.num_workers,
        collate_fn=partial(
            collate_fn,
            processor=processor
        ),
        drop_last=True,
        pin_memory=config.dataset_config.pin_memory,
    )

Am I setting up the dataloader properly here? I checked, and litGPT uses the torch DataLoader instead of StreamingDataLoader.

Here is what CPU and RAM usage looked like over an entire epoch:

image

You can see that instead of ramping up again to load the samples of the test set, it hangs.

Training using Fabric

Here are some parts of my training script, which basically follows the open-clip implementation but uses Lightning Fabric. Maybe I am doing something wrong when the epoch finishes? I noticed that the state is not saved on the last iteration of the epoch (only mid-training, every checkpoint_step):

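import math
import os
import time

import torch
from torch import nn
# Assumption: Fabric is imported from the unified `lightning` package.
from lightning.fabric import Fabric
from lightning.fabric.loggers import CSVLogger
from lightning.fabric.strategies import DDPStrategy

# Project-specific modules (not shown in this issue): configs, utils, models,
# optimization, data, losses, metrics

def set_fabric(logger: CSVLogger, config: configs.ExperimentConfig):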
    """Set fabric using VertexAI cluster environment
    """
    strategy = DDPStrategy(
        static_graph=True,
        cluster_environment=utils.infrastructure.VertexAICluster()
    )
    fabric = Fabric(
        accelerator=config.training_config.accelerator,
        strategy=strategy,
        devices=config.training_config.num_gpus,
        num_nodes=config.training_config.num_nodes,
        precision=config.training_config.precision,
        loggers=logger,
    )
    fabric.launch()
    # utils.data_streaming.set_streaming_env_vars(fabric) <- necessary for mosaicml-streaming
    return fabric

def save_model_weights(model: nn.Module, fabric: Fabric, experiment_config: configs.ExperimentConfig):
    """Save only model weights to remote storage,
     these are the ones that need to be called at inference time
    """
    if fabric.global_rank == 0:
        model.backbone.save_pretrained(experiment_config.model_path)

def save_state(
        model: nn.Module,
        state: dict,
        fabric: Fabric,
        experiment_config: configs.ExperimentConfig
):
    """
    Save state of the training session every checkpoint_step or at the end of the epoch
    """
    if state["current_step"] != 0 and (
            state["current_step"] % experiment_config.checkpoint_step == 0
            or state["iteration"] == state["num_batches_per_epoch"] - 1
    ):
        fabric.save(
            os.path.join(
                experiment_config.checkpoint_path,
                f"{state['current_epoch']:04d}-{state['current_step']:04d}-state.ckpt"
            ),
            state
        )
        save_model_weights(model, fabric, experiment_config)

def train_epoch(
        train_loader: torch.utils.data.DataLoader,
        model: nn.Module,
        state: dict,
        zero_shot_attributes: dict,
        optimizer: torch.optim.Optimizer,
        criterion: nn.Module,
        scheduler: torch.optim.lr_scheduler,
        fabric: Fabric,
        metrics_dict: dict,
        experiment_config: configs.ExperimentConfig
):
    """Train the model for one epoch (an entire training cycle)
    """
    model.train()

    accum_samples, accum_features = [], {}
    num_accumulated = 0

    for idx, batch in enumerate(train_loader, start=state["start_iteration"]):
        time_start = time.perf_counter()
        batch = fabric.to_device(batch)

        optimizer.zero_grad()

        with torch.no_grad():
            output = model(batch)
            output.pop("logit_scale", None)

            for key, value in output.items():
                if key not in accum_features:
                    accum_features[key] = [value]
                else:
                    accum_features[key].append(value)

            accum_samples.append(batch)

        num_accumulated += 1
        state["iteration"] += 1

        if (idx + 1) % experiment_config.training_config.accum_freq > 0:
            continue

        # compute embeddings on zero shot attributes
        with torch.no_grad():
            zero_shot_embeddings = model(zero_shot_attributes=zero_shot_attributes)

        # compute loss aggregating the embeddings accumulated
        optimizer.zero_grad()
        for j in range(num_accumulated):
            batch = accum_samples[j]
            output = model(batch)

            # keep scale partially
            inputs = {"logit_scale": output.pop("logit_scale")}
            for name, features in accum_features.items():
                accumulated = accum_features[name]
                inputs[name] = torch.cat(accumulated[:j] + [output[name]] + accumulated[j + 1:])

            loss = criterion(**inputs, fabric=fabric)

            del inputs
            fabric.backward(loss)

        optimizer.step()
        scheduler.step()

        # update metrics
        if check_if_is_log_step(state, experiment_config):
            update_metrics(
                "train", loss, accum_features, accum_samples, zero_shot_embeddings, metrics_dict, fabric
            )

        # reset accumulation
        accum_samples, accum_features = [], {}
        num_accumulated = 0

        with torch.no_grad():
            model.backbone.logit_scale.clamp_(0, math.log(100))

        timing = time.perf_counter() - time_start

        # log results
        if check_if_is_log_step(state, experiment_config):
            log_batch_metrics(
                timing, metrics_dict, fabric, state, scheduler.get_last_lr(), experiment_config
            )

        # save state if needed
        save_state(model, state, fabric, experiment_config)

        state["current_step"] += 1

    fabric.print(f"++++ Epoch {state['current_epoch']} completed ++++")

@torch.no_grad()
def validate_epoch(
        test_loader: torch.utils.data.DataLoader,
        model: nn.Module,
        state: dict,
        zero_shot_attributes: dict,
        criterion: nn.Module,
        metrics_dict: dict,
        fabric: Fabric,
        experiment_config: configs.ExperimentConfig
):
    fabric.barrier()
    fabric.print(f"++++ Validating epoch {state['current_epoch']} ++++")
    model.eval()
    for idx, batch in enumerate(test_loader):
        batch = fabric.to_device(batch)
        output = model(batch)
        zero_shot_embeddings = model(zero_shot_attributes=zero_shot_attributes)
        loss = criterion(**output, fabric=fabric)
        update_metrics("test", loss, output, batch, zero_shot_embeddings, metrics_dict, fabric)

        if idx % (experiment_config.verbose_step * 10) == 0 or idx == len(test_loader) - 1:
            fabric.print(
                f"+++ Epoch: {state['current_epoch']:04d} "
                f"| Test Step: {idx}/{len(test_loader)}"
                f"| Test Loss: {metrics_dict['test_loss'].compute().item()}"
                f"| Test Label Hit Rate {metrics_dict['test_label_hit_rate'].compute().item()}"
                f"| Test Precision@k {metrics_dict['test_precision_at_k'].compute().item()}"
                f"| Test Cosine Similarity {metrics_dict['test_cosine_similarity'].compute().item()}"
                f" +++"
            )

def fit(
        train_loader: torch.utils.data.DataLoader,
        test_loader: torch.utils.data.DataLoader,
        model: nn.Module,
        state: dict,
        zero_shot_attributes: dict,
        optimizer: torch.optim.Optimizer,
        criterion: nn.Module,
        scheduler: torch.optim.lr_scheduler,
        fabric: Fabric,
        metrics_dict: dict,
        experiment_config: configs.ExperimentConfig
):
    """Fit the model"""
    fabric.print(f"+++ Number of raw training batches: {len(train_loader)}")
    fabric.print(f"+++ Number of raw test batches: {len(test_loader)}")

    for epoch in range(state["current_epoch"], experiment_config.dataset_config.num_epochs):
        train_epoch(
            train_loader,
            model,
            state,
            zero_shot_attributes,
            optimizer,
            criterion,
            scheduler,
            fabric,
            metrics_dict,
            experiment_config
        )
        validate_epoch(
            test_loader,
            model,
            state,
            zero_shot_attributes,
            criterion,
            metrics_dict,
            fabric,
            experiment_config
        )

        log_epoch_metrics(metrics_dict, fabric)
        state["start_iteration"] = 0
        state["current_epoch"] += 1

    # save final version of the model & metadata
    save_model_weights(model, fabric, experiment_config)
    save_metadata(fabric, experiment_config)
    fabric.logger.finalize("success")

def get_current_epoch_iteration(
        train_loader: torch.utils.data.DataLoader,
        fabric: Fabric,
        experiment_config: configs.ExperimentConfig
) -> int:
    """
    Get epoch iteration on resume for enumerate to start
    """
    state_dict = train_loader.state_dict()
    num_train_samples = experiment_config.dataset_config.num_train_samples

    if "num_samples_yielded" in state_dict:
        # litdata format
        samples_seen = state_dict["num_samples_yielded"]
        current_epoch = state_dict["current_epoch"]
    else:
        # mosaic-ml format
        samples_seen = state_dict["sample_in_epoch"]
        current_epoch = state_dict["epoch"]

    iteration = (samples_seen * len(train_loader)) // num_train_samples
    fabric.print(f"++++ Resuming epoch {current_epoch} from iteration: {iteration} ++++")
    return iteration

def get_initial_state(model, optimizer, train_loader, scheduler):
    """Get the state of the session"""
    return {
        "current_epoch": 0,
        "current_step": 0,  # model update
        "iteration": 0,  # dataset-wise
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "train_loader": train_loader.state_dict(),
        "scheduler": scheduler.state_dict()
    }

def get_state(
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        train_loader: torch.utils.data.DataLoader,
        scheduler: torch.optim.lr_scheduler,
        fabric: Fabric,
        experiment_config: configs.ExperimentConfig
):
    """Loads a checkpoint from a given file into state from remote storage
    """
    current_checkpoints = sorted(utils.io.glob_gcs(
        os.path.join(experiment_config.checkpoint_path, "*-state.ckpt")
    ), reverse=True)
    if current_checkpoints:
        fabric.print(
            f"++++ Resuming training using state from: {os.path.basename(current_checkpoints[0])} ++++"
        )
        state = fabric.load(current_checkpoints[0].replace("gs://", "/gcs/"))
        model.load_state_dict(state["model"])
        optimizer.load_state_dict(state["optimizer"])
        train_loader.load_state_dict(state["train_loader"])
        scheduler.load_state_dict(state["scheduler"])
        # modify start iteration state in case of resuming training
        state["start_iteration"] = get_current_epoch_iteration(train_loader, fabric, experiment_config)
    else:
        fabric.print("++++ Starting Training from epoch: 0 ++++")
        state = get_initial_state(model, optimizer, train_loader, scheduler)
        state["start_iteration"] = 0

    # extra metadata
    state["num_batches_per_epoch"] = len(train_loader)
    state["num_steps_per_epoch"] = len(train_loader) // experiment_config.training_config.accum_freq
    return state

def launch_training(experiment_config: configs.ExperimentConfig):
    """Set up the training environment, launcher and data loaders
    """
    logger = CSVLogger(
        root_dir=experiment_config.artifacts_path,
        name="logs", version="clip"
    )
    fabric = set_fabric(logger, experiment_config)

    fabric.print("++++ Setting up model and optimizer ++++")
    model = models.get_model(experiment_config)
    optimizer = optimization.get_optimizer(model, experiment_config)
    model, optimizer = fabric.setup(model, optimizer)
    lr_scheduler = optimization.get_learning_rate(optimizer, experiment_config)

    fabric.print("++++ Setting up dataloaders ++++")
    train_loader = data.get_dataloader("train", experiment_config)
    test_loader = data.get_dataloader("test", experiment_config)
    fabric.barrier()

    # get and resume state if available
    state = get_state(
        model, optimizer, train_loader, lr_scheduler, fabric, experiment_config
    )

    # get losses
    criterion = losses.get_loss(experiment_config, fabric)
    models.print_trainable_parameters(model, fabric, experiment_config.training_config)

    # get zero shot tokens
    fabric.print("++++ Loading zero shot attribute tokens ++++")
    zero_shot_attributes = fabric.to_device(
        torch.load(experiment_config.zero_shot_tokens_fn)
    )

    # define metrics to track during training/validation
    metrics_dict = metrics.get_metrics(fabric.device)

    # sync before starting training
    fabric.barrier()

    fit(
        train_loader,
        test_loader,
        model,
        state,
        zero_shot_attributes,
        optimizer,
        criterion,
        lr_scheduler,
        fabric,
        metrics_dict,
        experiment_config
    )
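On the end-of-epoch checkpoint that never appears: the following is a minimal sketch, assuming the counters behave exactly as in train_epoch and save_state above (state["iteration"] is incremented before save_state runs, and save_state is only reached on accumulation boundaries). The function name and the end_value parameter are hypothetical and exist only for illustration. The sketch replays the counters without any model, checkpoint_step, or Fabric calls.

```python
from typing import List


def end_of_epoch_hits(num_batches: int, accum_freq: int, end_value: int) -> List[int]:
    """Replay the loop counters and return the batch indices at which the
    end-of-epoch clause `iteration == end_value` would be true in save_state."""
    iteration = 0      # mirrors state["iteration"]
    current_step = 0   # mirrors state["current_step"]
    hits = []
    for idx in range(num_batches):
        iteration += 1  # incremented before save_state is reached
        if (idx + 1) % accum_freq > 0:
            continue    # still accumulating: save_state is not called
        if current_step != 0 and iteration == end_value:
            hits.append(idx)
        current_step += 1
    return hits


if __name__ == "__main__":
    n = 8
    # Condition as posted: iteration == num_batches_per_epoch - 1
    print(end_of_epoch_hits(n, accum_freq=4, end_value=n - 1))  # []
    print(end_of_epoch_hits(n, accum_freq=1, end_value=n - 1))  # [6]
    # Hypothetical alternative: compare against num_batches_per_epoch
    print(end_of_epoch_hits(n, accum_freq=4, end_value=n))      # [7]
```

Under these assumptions the posted condition never matches when accum_freq is greater than 1, and matches one batch early when accum_freq is 1. Note that state["iteration"] is also never reset between epochs in the posted code, so the alternative comparison would only match during the first epoch unless the counter is reset alongside start_iteration.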

Environment

Expected behavior

The training dataloader finishes the epoch and the rest of the code continues executing.

tchaton commented 1 month ago

Hey @miguelalba96. Thanks for reporting this issue.

Would you mind printing the length of each dataset and dataloader on each rank? It usually hangs when one rank has more data than the others. It shouldn't happen, but I want to exclude this eventuality.

Do you think you could share a tiny reproducible example with dummy data for me to debug?
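To make the check concrete, here is a small sketch of why uneven per-rank batch counts lead to a hang under DDP: every optimizer step involves a collective operation that all ranks must join, so once the rank with the fewest batches leaves its loop, the others wait indefinitely. The helper names and the example counts are hypothetical. On a real run the numbers would come from len(train_loader) printed on each rank (for example, alongside fabric.global_rank).

```python
from typing import Dict, List, Optional


def ranks_with_extra_batches(batches_per_rank: Dict[int, int]) -> List[int]:
    """Ranks that would keep iterating after the shortest rank has finished."""
    fewest = min(batches_per_rank.values())
    return sorted(rank for rank, n in batches_per_rank.items() if n > fewest)


def first_stalled_step(batches_per_rank: Dict[int, int]) -> Optional[int]:
    """Step index at which some rank has run out of batches while others still
    expect a collective call; None when every rank has the same count."""
    fewest = min(batches_per_rank.values())
    most = max(batches_per_rank.values())
    return fewest if most > fewest else None


if __name__ == "__main__":
    even = {0: 100, 1: 100, 2: 100, 3: 100}
    uneven = {0: 100, 1: 100, 2: 101, 3: 100}
    print(ranks_with_extra_batches(even), first_stalled_step(even))      # [] None
    print(ranks_with_extra_batches(uneven), first_stalled_step(uneven))  # [2] 100
```

If the printed lengths are identical on every rank, this particular cause can be ruled out. The lengths are only what each loader reports up front, though, so they do not prove that every rank actually yields that many batches at runtime.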

Best, T.C

miguelalba96 commented 1 month ago

When I print len(dataset) and len(dataloader) on every rank of every node, I get the same number of samples on each:

image

I'm not sure how to reproduce this problem; I will check. I also noticed that when I load the state to resume training using the get_state function I wrote above, the dataloader doesn't seem to resume properly and iterates over the data all over again until it hangs 🤔:

image
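One detail worth separating out when debugging the resume path: in train_epoch, the start argument of enumerate only relabels the indices and does not skip any batches. Whether the loader itself fast-forwards after load_state_dict is not established in this thread. The sketch below only illustrates the behaviour of enumerate, using a plain list in place of the dataloader; the helper names are hypothetical.

```python
from itertools import islice
from typing import Iterable, List, Tuple


def relabel(batches: Iterable[str], start: int) -> List[Tuple[int, str]]:
    """What enumerate(loader, start=k) does: same items, shifted indices."""
    return list(enumerate(batches, start=start))


def skip(batches: Iterable[str], start: int) -> List[Tuple[int, str]]:
    """Explicitly drop the first `start` items, then number the remainder."""
    return list(enumerate(islice(batches, start, None), start=start))


if __name__ == "__main__":
    batches = ["b0", "b1", "b2", "b3", "b4"]
    print(relabel(batches, 2))  # [(2, 'b0'), (3, 'b1'), (4, 'b2'), (5, 'b3'), (6, 'b4')]
    print(skip(batches, 2))     # [(2, 'b2'), (3, 'b3'), (4, 'b4')]
```

If the loader does not restore its position on its own, the loop as posted would revisit every batch with indices starting at start_iteration, which would match the observation above. Skipping with islice on a real dataloader would still load the discarded batches, so it is shown here only as a contrast, not as a recommended fix.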

tchaton commented 1 month ago

Hey @miguelalba96, any chance you could create a reproducible Studio on https://lightning.ai/ that I can duplicate to investigate what's happening? Otherwise, it is hard for me to help you.