ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[Ray Train] session.report() does not strip the module. prefixes #39039

Closed: JacksonCakes closed this 11 months ago

JacksonCakes commented 1 year ago

What happened + What you expected to happen

During training, I would like to save my model checkpoint. I tried saving the checkpoint using:

state_dict = model.state_dict()
# Method 1
checkpoint = TorchCheckpoint.from_state_dict(
        state_dict=state_dict
    )
# Method 2
checkpoint = TorchCheckpoint.from_dict(
        dict(epoch=e, model=state_dict)
    )

But when I try to load the checkpoint with predictor = TorchPredictor.from_checkpoint(result.checkpoint, BertClassifier(), use_gpu=True), it fails with Unexpected key(s) in state_dict: "module.bert.embeddings.position_ids"... This looks similar to #36639, but I don't think that issue is solved.

Also, when I looked at the documentation, it mentioned that using session.report() strips all the `module.` prefixes, so the saved state_dict shouldn't have `module.` prefixes anymore. Am I correct?
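
As a workaround, I can strip the prefixes myself before building the checkpoint. A minimal sketch (the dict comprehension below is my own code, not a Ray API; it assumes prepare_model() wrapped the model in DistributedDataParallel, which is what adds the "module." prefix):

state_dict = model.state_dict()
# Drop the "module." prefix that DistributedDataParallel adds to every key.
stripped_state_dict = {
    key[len("module."):] if key.startswith("module.") else key: value
    for key, value in state_dict.items()
}
checkpoint = TorchCheckpoint.from_dict(dict(epoch=e, model=stripped_state_dict))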

Versions / Dependencies

python==3.8.15 ray==2.6.1

Reproduction script

from ray.air import session, Checkpoint
from ray.air.config import ScalingConfig, RunConfig
from ray.train.torch import TorchTrainer
from ray.train.torch import TorchPredictor, TorchCheckpoint
from ray.train.data_parallel_trainer import _load_checkpoint_dict
import ray.train as train
from ray.train.predictor import Predictor
import torch.nn as nn
from datasets import load_dataset
from transformers import AutoModel
from typing import Dict
import torch
import time
import ray
import pandas as pd
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

dataset = load_dataset("imdb")

def tokenize(batch: dict, tokenizer: AutoTokenizer) -> dict:
    # datasets.map(batched=True) passes a dict of columns, not a DataFrame.
    result = tokenizer(
        list(batch["text"]),
        truncation=True,
        max_length=128,
        padding="max_length",
        return_tensors="pt",
    )
    result["labels"] = batch["label"].copy()
    return dict(result)

class BertClassifier(nn.Module):
    def __init__(self, model_name, num_labels):
        super(BertClassifier, self).__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input: Dict):
        outputs = self.bert(input["input_ids"], input["attention_mask"])
        pooled_output = outputs[1]
        logits = self.classifier(pooled_output)
        return logits

def train_loop_per_worker(config: Dict):
    st = time.time()
    tokenizer = AutoTokenizer.from_pretrained(config["model_name"], padding_side="left")
    train_dataset = dataset['train'].map(tokenize, fn_kwargs={"tokenizer":tokenizer},batched=True)
    test_dataset = dataset['test'].map(tokenize, fn_kwargs={"tokenizer":tokenizer},batched=True)

    train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
    test_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
    # Create PyTorch DataLoader
    train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=512)

    # [1] Prepare Dataloader for distributed training
    # Shard the datasets among workers and move batches to the correct device
    # =======================================================================
    train_dataloader = ray.train.torch.prepare_data_loader(train_loader)
    test_dataloader = ray.train.torch.prepare_data_loader(test_loader)

    model = BertClassifier(config["model_name"],2)
    model = train.torch.prepare_model(model)
    lr = config["lr"]
    epochs = config["epochs"]
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    # Resume from a previously reported checkpoint, if one is available.
    ckpt = session.get_checkpoint()
    if ckpt:
        ckpt_dict = ckpt.to_dict()
        model.load_state_dict(ckpt_dict["model"])
    model = train.torch.prepare_model(model)
    for e in range(epochs):
        model.train()
        for batch in train_dataloader:
            input_ids, attention_mask, label = batch["input_ids"],batch["attention_mask"],batch["labels"]
            # Compute prediction error
            pred = model({"input_ids":input_ids,"attention_mask":attention_mask})
            loss = loss_fn(pred,label)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss = loss.item()

        print(f"loss: {loss:>7f}")
        state_dict = model.state_dict()
        checkpoint = TorchCheckpoint.from_dict(
            dict(epoch=e, model=state_dict)
        )
        # checkpoint.to_directory("/home/jackson/mass_search_prod/docs")
        session.report({"loss": loss}, checkpoint=checkpoint)
    print(f"Total elapsed time: {time.time()-st}")

trainer = TorchTrainer(
    train_loop_per_worker,
    train_loop_config={
        "model_name": "prajjwal1/bert-tiny",
        "lr": 1e-3,
        "epochs": 10,
    },
    scaling_config=ScalingConfig(
        num_workers=2,
        use_gpu=True,
        resources_per_worker={"CPU": 1, "GPU": 1},
        placement_strategy="SPREAD",
    ),
    run_config=RunConfig(storage_path="/home/jackson/Ray_Tutorial/ray_results"),
)
result = trainer.fit()

predictor = TorchPredictor.from_checkpoint(result.checkpoint, BertClassifier("prajjwal1/bert-tiny", 2), use_gpu=True)

Issue Severity

High: It blocks me from completing my task.

matthewdeng commented 11 months ago

Closing this since the prefix-stripping logic has been removed in 2.7, and the user now has full flexibility over what stripped or unstripped data to save in the checkpoint.
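
For anyone hitting this on Ray >= 2.7: a minimal sketch of saving the unwrapped weights inside train_loop_per_worker, assuming the ray.train.report / Checkpoint.from_directory API and that prepare_model() wrapped the model in DistributedDataParallel (model, e, and loss refer to the reproduction script above):

import os
import tempfile

import torch
from ray import train
from ray.train import Checkpoint

# At the end of each epoch: save the unwrapped model (model.module) so the
# state_dict keys carry no "module." prefix, then report the checkpoint.
with tempfile.TemporaryDirectory() as tmpdir:
    torch.save(
        {"epoch": e, "model": model.module.state_dict()},
        os.path.join(tmpdir, "checkpoint.pt"),
    )
    train.report({"loss": loss}, checkpoint=Checkpoint.from_directory(tmpdir))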