What happened + What you expected to happen

During training, I would like to save my model checkpoint. I tried saving the checkpoint with TorchCheckpoint.from_dict() and session.report() (see the reproduction script below).
But when I try to load the checkpoint with:

predictor = TorchPredictor.from_checkpoint(result.checkpoint, BertClassifier(), use_gpu=True)

it fails with the error:

Unexpected key(s) in state_dict: "module.bert.embeddings.position_ids"...
This looks similar to #36639, but I don't think that issue is solved.

Also, when I looked at the documentation, it mentioned that using session.report() will strip all the `module.` prefixes, so the saved `state_dict` shouldn't have `module.` prefixes anymore. Am I correct?
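To double-check what actually got saved, I inspected the checkpoint keys directly (a quick sketch; it assumes the {"epoch": ..., "model": ...} dict layout used in the reproduction script below):

# Inspect the keys actually stored in the checkpoint. Assumes the
# {"epoch": ..., "model": ...} layout from the reproduction script.
ckpt_dict = result.checkpoint.to_dict()
print(list(ckpt_dict["model"].keys())[:5])
# On 2.6.1 this still shows keys like "module.bert.embeddings.position_ids",
# i.e. the "module." prefix was not stripped.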
Versions / Dependencies
python==3.8.15
ray==2.6.1
Reproduction script
import time
from typing import Dict

import pandas as pd
import ray
import ray.train as train
import torch
import torch.nn as nn
from datasets import load_dataset
from ray.air import session
from ray.air.config import RunConfig, ScalingConfig
from ray.train.torch import TorchCheckpoint, TorchPredictor, TorchTrainer
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer
dataset = load_dataset("imdb")
def tokenize(batch: pd.DataFrame, tokenizer: AutoTokenizer) -> dict:
    result = tokenizer(
        list(batch["text"]),
        truncation=True,
        max_length=128,
        padding="max_length",
        return_tensors="pt",
    )
    result["labels"] = batch["label"].copy()
    return dict(result)
class BertClassifier(nn.Module):
    def __init__(self, model_name, num_labels):
        super(BertClassifier, self).__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input: Dict):
        outputs = self.bert(input["input_ids"], input["attention_mask"])
        pooled_output = outputs[1]
        logits = self.classifier(pooled_output)
        return logits
def train_loop_per_worker(config: Dict):
    st = time.time()
    tokenizer = AutoTokenizer.from_pretrained(config["model_name"], padding_side="left")
    train_dataset = dataset["train"].map(tokenize, fn_kwargs={"tokenizer": tokenizer}, batched=True)
    test_dataset = dataset["test"].map(tokenize, fn_kwargs={"tokenizer": tokenizer}, batched=True)
    train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
    test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

    # Create PyTorch DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=512)

    # [1] Prepare DataLoaders for distributed training:
    # shard the datasets among workers and move batches to the correct device.
    train_dataloader = ray.train.torch.prepare_data_loader(train_loader)
    test_dataloader = ray.train.torch.prepare_data_loader(test_loader)

    model = BertClassifier(config["model_name"], 2)
    model = train.torch.prepare_model(model)

    lr = config["lr"]
    epochs = config["epochs"]
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    # Resume from a previous checkpoint if one is available.
    ckpt = session.get_checkpoint()
    if ckpt:
        ckpt_dict = ckpt.to_dict()
        model.load_state_dict(ckpt_dict["model"])
        model = train.torch.prepare_model(model)

    for e in range(epochs):
        model.train()
        for batch in train_dataloader:
            input_ids, attention_mask, label = (
                batch["input_ids"],
                batch["attention_mask"],
                batch["labels"],
            )
            # Compute prediction error
            pred = model({"input_ids": input_ids, "attention_mask": attention_mask})
            loss = loss_fn(pred, label)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss = loss.item()
            print(f"loss: {loss:>7f}")

        # Save a checkpoint at the end of every epoch.
        state_dict = model.state_dict()
        checkpoint = TorchCheckpoint.from_dict(dict(epoch=e, model=state_dict))
        # checkpoint.to_directory("/home/jackson/mass_search_prod/docs")
        session.report({"loss": loss}, checkpoint=checkpoint)

    print(f"Total elapsed time: {time.time() - st}")
trainer = TorchTrainer(
    train_loop_per_worker,
    train_loop_config={
        "model_name": "prajjwal1/bert-tiny",
        "lr": 1e-3,
        "epochs": 10,
    },
    scaling_config=ScalingConfig(
        num_workers=2,
        use_gpu=True,
        resources_per_worker={"CPU": 1, "GPU": 1},
        placement_strategy="SPREAD",
    ),
    run_config=RunConfig(storage_path="/home/jackson/Ray_Tutorial/ray_results"),
)
result = trainer.fit()

predictor = TorchPredictor.from_checkpoint(
    result.checkpoint, BertClassifier("prajjwal1/bert-tiny", 2), use_gpu=True
)
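For what it's worth, my understanding (an assumption on my part, not something I found stated in the Ray docs) is that the `module.` prefix comes from prepare_model() wrapping the model in DistributedDataParallel, so calling state_dict() on the wrapped model yields prefixed keys. A plain-PyTorch sketch that shows the mechanism:

# Minimal single-process demo of where the "module." prefix comes from.
import os
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = nn.Linear(4, 2)
print(list(model.state_dict()))             # ['weight', 'bias']

ddp_model = DDP(model)
print(list(ddp_model.state_dict()))         # ['module.weight', 'module.bias']
print(list(ddp_model.module.state_dict()))  # ['weight', 'bias'] -- unwrapped

dist.destroy_process_group()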
Issue Severity

High: It blocks me from completing my task.

Closing this since the prefix-stripping logic has been removed in 2.7, and the user has full flexibility over what stripped/unstripped data to save in the checkpoint.
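For anyone who hits this on 2.6.x before upgrading: a minimal workaround sketch (my own, not an official Ray API) is to strip the `module.` prefix from the keys before building the checkpoint, so the saved state_dict matches the bare BertClassifier:

# Workaround sketch for Ray 2.6.x: save an unprefixed state_dict so that
# TorchPredictor.from_checkpoint() can load it into an unwrapped model.
raw = model.state_dict()  # keys look like "module.bert.embeddings..."
stripped = {
    (k[len("module."):] if k.startswith("module.") else k): v
    for k, v in raw.items()
}
checkpoint = TorchCheckpoint.from_dict(dict(epoch=e, model=stripped))
session.report({"loss": loss}, checkpoint=checkpoint)

Equivalently, if model is the DDP wrapper returned by prepare_model(), model.module.state_dict() should give the unprefixed keys directly.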