Nintorac closed this issue 2 years ago.
From the log files, it appears that the clients and the server were proceeding normally until an ABORT command was somehow received.
An ABORT can happen in several cases. From the info in the log files, I suspect the run was aborted forcefully by an ABORT command issued from the user/script. If you used a script, could you please paste its content?
The fact that you got them consistently at the 33-minute mark suggests there is probably some kind of timeout in the script you run on your PC. We have run into this problem before. If that is the case, you just need to extend the timeout.
Yeah, this seems to be the case. I am using the AdminAPIRunner.run command, which has a default timeout of 2000 seconds (33.33 minutes).
Thanks :)
Glad it's an easy fix. :-)
Is it worth changing this default to 0 and making timeouts opt-in?
Setting the value to 0 gets an instant abort.
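For anyone hitting the same 33-minute abort, here is a minimal sketch of extending that timeout when launching the run. It assumes the runner in question is FLAdminAPIRunner from nvflare.fuel.hci.client.fl_admin_api_runner and that its run() accepts the timeout keyword mentioned above (default 2000 s); the constructor arguments below are placeholders, so check them against the NVFlare version you are running.

# Sketch: pass an explicit, generous timeout instead of relying on the
# 2000 s default (~33 min) that was aborting the run above.
from nvflare.fuel.hci.client.fl_admin_api_runner import FLAdminAPIRunner

runner = FLAdminAPIRunner(
    host="localhost",          # placeholder connection details
    port=8003,
    username="admin@nvidia.com",
    admin_dir="./admin",
)

# Raise the timeout well beyond the expected wall-clock time of the whole job,
# e.g. 24 hours, so the admin runner never cancels a healthy run.
runner.run(run_number=1, timeout=24 * 60 * 60)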
Hi all, I have the same problem. My training is aborted at random every time, after roughly 36 global epochs. I am using a modified version of the CIFAR-10 trainer, and I can't find a timeout in there. @yanchengnv, do you by any chance have an idea where the abort comes from in my case?
Trainer:
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
from monai.metrics import DiceMetric
import numpy as np
import torch
import torch.optim as optim
from monai.data import create_test_image_2d, list_data_collate, decollate_batch
from monai.inferers import sliding_window_inference
from pt.networks.cifar10_nets import ModerateCNN
from pt.utils.cifar10_dataset import CIFAR10_Idx
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from nvflare.apis.dxo import DXO, DataKind, MetaKey, from_shareable
from nvflare.apis.fl_constant import FLContextKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.learner_spec import Learner
from nvflare.app_common.app_constant import AppConstants, ModelName, ValidateType
from nvflare.app_common.pt.pt_fedproxloss import PTFedProxLoss
import json
import logging
import torch.distributed as dist
from monai.data import (
CacheDataset,
DataLoader,
load_decathlon_datalist,
partition_dataset,
)
from monai.engines import SupervisedEvaluator, SupervisedTrainer
from monai.handlers import (
CheckpointSaver,
LrScheduleHandler,
MeanDice,
StatsHandler,
TensorBoardStatsHandler,
ValidationHandler,
)
from monai.inferers import SimpleInferer, SlidingWindowInferer
from monai.losses import DiceLoss, DiceFocalLoss
from monai.networks.layers import Norm
from monai.networks.nets import UNet, BasicUNet, SegResNet
from monai.transforms import (
Activations,
AddChanneld,
AsDiscrete,
Activationsd,
AsDiscreted,
Compose,
CropForegroundd,
EnsureChannelFirstd,
LoadImaged,
Orientationd,
RandCropByPosNegLabeld,
ScaleIntensityRanged,
ScaleIntensityd,
Spacingd,
ToTensord,
RandGaussianNoised,
RandShiftIntensityd,
HistogramNormalized,
Resized
)
from torch.nn.parallel import DistributedDataParallel
from monai.handlers import from_engine
import custom.ImageDataloader as ImageDataloader
from custom.model import get_model
from custom.processing import Move_channel_to_front
class CIFAR10Learner(Learner):
def __init__(
self,
dataset_root: str = "./dataset",
aggregation_epochs: int = 1,
train_task_name: str = AppConstants.TASK_TRAIN,
submit_model_task_name: str = AppConstants.TASK_SUBMIT_MODEL,
lr: float = 1e-2,
fedproxloss_mu: float = 0.0,
central: bool = False,
):
"""Simple CIFAR-10 Trainer.
Args:
dataset_root: directory with CIFAR-10 data.
aggregation_epochs: the number of training epochs for a round. Defaults to 1.
train_task_name: name of the task to train the model.
submit_model_task_name: name of the task to submit the best local model.
Returns:
a Shareable with the updated local model after running `execute()`
or the best local model depending on the specified task.
"""
super().__init__()
# trainer init happens at the very beginning, only the basic info regarding the trainer is set here
# the actual run has not started at this point
self.aggregation_epochs = aggregation_epochs
self.train_task_name = train_task_name
self.lr = lr
self.fedproxloss_mu = fedproxloss_mu
self.submit_model_task_name = submit_model_task_name
self.best_acc = 0.0
self.central = central
self.initialized = False
# Epoch counter
self.epoch_of_start_time = 0
self.epoch_global = 0
def initialize(self, parts: dict, fl_ctx: FLContext):
# when the run starts, this is where the actual settings get initialized for trainer
# Set the paths according to fl_ctx
# wf_config_file_name=fl_args.train_config,
# app_root = config_root
self.dataset_root = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT).get_root_dir() #dataset_root
self.app_root = fl_ctx.get_prop(FLContextKey.APP_ROOT)
fl_args = fl_ctx.get_prop(FLContextKey.ARGS)
self.client_id = fl_ctx.get_identity_name()
self.log_info(
fl_ctx,
f"Client {self.client_id} initialized at \n {self.app_root} \n with args: {fl_args}",
)
self.local_model_file = os.path.join(self.app_root, "local_model.pt")
self.best_local_model_file = os.path.join(self.app_root, "best_local_model.pt")
# Set local tensorboard writer - to be replaced by event
self.writer = SummaryWriter(self.app_root)
# Set datalist, here the path and filename are hard-coded, can also be fed as an argument
site_idx_file_name = os.path.join(self.app_root, fl_args.train_config)
with open(os.path.join(self.app_root, fl_args.train_config)) as file:
wf_config = json.load(file)
self.wf_config = wf_config
self.max_epochs = wf_config["max_epochs"]
self.learning_rate = wf_config["learning_rate"]
self.data_list_base_dir = self.app_root
self.data_list = wf_config["data_list_json_file"]
self.val_interval = wf_config["val_interval"]
self.ckpt_dir = wf_config["ckpt_dir"]
self.amp = wf_config["amp"]
self.use_gpu = wf_config["use_gpu"]
self.multi_gpu = wf_config["multi_gpu"]
self.local_rank = 0
self.identity_name = fl_ctx.get_identity_name()
self.data_list_json_file = self.data_list + self.identity_name + ".json"
site_idx_file_name = os.path.join(self.dataset_root, self.client_id + ".npy")
self.log_info(fl_ctx, f"IndexList Path: {site_idx_file_name}")
# set the training-related parameters
# can be replaced by a config-style block
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.model = UNet(
dimensions=2,
in_channels=1,
out_channels=33,
channels=(32, 64, 128, 256,512),
strides=(2, 2, 2, 2),
num_res_units=2,
norm=Norm.BATCH,
).to(self.device)
self.transform_train = Compose(
[
LoadImaged(keys=("image", "label")),
Move_channel_to_front(keys=("image")),
Resized(keys=("image", "label"),spatial_size=(224,224)),
AddChanneld(keys=["image", "label"]),
ScaleIntensityd(
keys="image",
minv=0.0,
maxv=1.0,
),
RandGaussianNoised(keys=("image"), prob=0),
RandShiftIntensityd(keys=("image"), offsets=0.1, prob=0),
# HistogramNormalized(keys=("image"), min=0.0, max=1.0),
ToTensord(keys=("image", "label")),
]
)
self.transform_valid = Compose(
[
LoadImaged(keys=("image", "label")),
Move_channel_to_front(keys=("image")),
Resized(keys=("image", "label"),spatial_size=(224,224)),
AddChanneld(keys=["image", "label"]),
ScaleIntensityd(
keys="image",
minv=0.0,
maxv=1.0,
),
# HistogramNormalized(keys=("image"), min=0.0, max=1.0),
ToTensord(keys=("image", "label")),
]
)
# set datalist
self.train_dataset = load_decathlon_datalist(
os.path.join(self.data_list_base_dir, self.data_list_json_file),
is_segmentation=True,
data_list_key="training",
base_dir=self.data_list_base_dir,
)
self.valid_dataset = load_decathlon_datalist(
os.path.join(self.data_list_base_dir, self.data_list_json_file),
is_segmentation=True,
data_list_key="validation",
base_dir=self.data_list_base_dir,
)
# set datalist
self.train_loader = ImageDataloader.ImageDataset(
data=self.train_dataset,
transform=self.transform_train
)
self.valid_loader = ImageDataloader.ImageDataset(
data=self.valid_dataset,
transform= self.transform_valid
)
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
self.criterion = DiceFocalLoss(include_background=False, to_onehot_y=False, softmax=True)
if self.fedproxloss_mu > 0:
self.log_info(fl_ctx, f"using FedProx loss with mu {self.fedproxloss_mu}")
self.criterion_prox = PTFedProxLoss(mu=self.fedproxloss_mu)
self.post_transform = Compose(
[
Activations(softmax=True),
AsDiscrete(threshold=0.5)
]
)
self.dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
self.inferer = SimpleInferer()
def finalize(self, fl_ctx: FLContext):
# collect threads, close files here
pass
def local_train(self, fl_ctx, train_loader, model_global, abort_signal: Signal, val_freq: int = 0):
for epoch in range(self.aggregation_epochs):
if abort_signal.triggered:
return
self.model.train()
epoch_len = len(train_loader)
self.epoch_global = self.epoch_of_start_time + epoch
self.log_info(fl_ctx, f"Local epoch {self.client_id}: {epoch + 1}/{self.aggregation_epochs} (lr={self.lr})")
for i,batch_data in enumerate(train_loader):
inputs, labels = batch_data["image"].to(self.device), batch_data["label"].to(self.device)
if abort_signal.triggered:
return
# zero the parameter gradients
self.optimizer.zero_grad()
# forward + backward + optimize
outputs = self.model(inputs)
loss = self.criterion(outputs, labels)
# FedProx loss term
if self.fedproxloss_mu > 0:
fed_prox_loss = self.criterion_prox(self.model, model_global)
loss += fed_prox_loss
loss.backward()
self.optimizer.step()
self.writer.add_scalar("train_loss", loss.item(), epoch_len * self.epoch_global + i)
if val_freq > 0 and epoch % val_freq == 0:
acc = self.local_valid(self.valid_loader, abort_signal, tb_id="val_acc_local_model")
if acc > self.best_acc:
self.save_model(is_best=True)
def save_model(self, is_best=False):
# save model
model_weights = self.model.state_dict()
save_dict = {"model_weights": model_weights, "epoch": self.epoch_global}
if is_best:
save_dict.update({"best_acc": self.best_acc})
torch.save(save_dict, self.best_local_model_file)
else:
torch.save(save_dict, self.local_model_file)
def train(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
# Check abort signal
if abort_signal.triggered:
return make_reply(ReturnCode.TASK_ABORTED)
try:
# get round information
current_round = shareable.get_header(AppConstants.CURRENT_ROUND)
total_rounds = shareable.get_header(AppConstants.NUM_ROUNDS)
self.log_info(fl_ctx, f"Current/Total Round: {current_round + 1}/{total_rounds}")
self.log_info(fl_ctx, f"Client identity: {fl_ctx.get_identity_name()}")
# update local model weights with received weights
dxo = from_shareable(shareable)
global_weights = dxo.data
# Before loading weights, tensors might need to be reshaped to support HE for secure aggregation.
local_var_dict = self.model.state_dict()
model_keys = global_weights.keys()
for var_name in local_var_dict:
if var_name in model_keys:
weights = global_weights[var_name]
# reshape global weights to compute difference later on
global_weights[var_name] = np.reshape(weights, local_var_dict[var_name].shape)
# update the local dict
local_var_dict[var_name] = torch.as_tensor(global_weights[var_name])
except Exception as e:
raise ValueError("Convert weight from {} failed with error: {}".format(var_name, str(e)))
self.model.load_state_dict(local_var_dict)
# local steps
epoch_len = len(self.train_loader)
self.log_info(fl_ctx, f"Local steps per epoch: {epoch_len}")
# make a copy of model_global as reference for potential FedProx loss
if self.fedproxloss_mu > 0:
model_global = copy.deepcopy(self.model)
for param in model_global.parameters():
param.requires_grad = False
else:
model_global = None
# local train
self.local_train(
fl_ctx=fl_ctx,
train_loader=self.train_loader,
model_global=model_global,
abort_signal=abort_signal,
val_freq=1 if self.central else 0,
)
if abort_signal.triggered:
return make_reply(ReturnCode.TASK_ABORTED)
self.epoch_of_start_time += self.aggregation_epochs
# perform valid after local train
acc = self.local_valid(self.valid_loader, abort_signal, tb_id="val_acc_local_model")
if abort_signal.triggered:
return make_reply(ReturnCode.TASK_ABORTED)
self.log_info(fl_ctx, f"val_acc_local_model: {acc:.4f}")
# save model
self.save_model(is_best=False)
if acc > self.best_acc:
self.save_model(is_best=True)
# compute delta model, global model has the primary key set
local_weights = self.model.state_dict()
model_diff = {}
for name in global_weights:
if name not in local_weights:
continue
model_diff[name] = local_weights[name].cpu().numpy() - global_weights[name]
if np.any(np.isnan(model_diff[name])):
self.system_panic(f"{name} weights became NaN...", fl_ctx)
return make_reply(ReturnCode.EXECUTION_EXCEPTION)
# build the shareable
dxo = DXO(data_kind=DataKind.WEIGHT_DIFF, data=model_diff)
dxo.set_meta_prop(MetaKey.NUM_STEPS_CURRENT_ROUND, epoch_len)
self.log_info(fl_ctx, "Local epochs finished. Returning shareable")
return dxo.to_shareable()
def get_model_for_validation(self, model_name: str, fl_ctx: FLContext) -> Shareable:
# Retrieve the best local model saved during training.
if model_name == ModelName.BEST_MODEL:
model_data = None
try:
# load model to cpu as server might or might not have a GPU
model_data = torch.load(self.best_local_model_file, map_location="cpu")
except Exception as e:
self.log_error(fl_ctx, f"Unable to load best model: {e}")
# Create DXO and shareable from model data.
if model_data:
dxo = DXO(data_kind=DataKind.WEIGHTS, data=model_data["model_weights"])
return dxo.to_shareable()
else:
# Set return code.
self.log_error(fl_ctx, f"best local model not found at {self.best_local_model_file}.")
return make_reply(ReturnCode.EXECUTION_RESULT_ERROR)
else:
raise ValueError(f"Unknown model_type: {model_name}") # Raised errors are caught in LearnerExecutor class.
def local_valid(self, valid_loader, abort_signal: Signal, tb_id=None):
self.model.eval()
with torch.no_grad():
for val_data in valid_loader:
val_images, val_labels = val_data["image"].to(self.device), val_data["label"].to(self.device)
# define sliding window size and batch size for windows inference
roi_size = (96, 96)
sw_batch_size = 4
# val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, self.model)
val_outputs = self.inferer(val_images, self.model)
val_outputs = [self.post_transform(i) for i in decollate_batch(val_outputs)]
val_labels = decollate_batch(val_labels)
# compute metric for current iteration
self.dice_metric(y_pred=val_outputs, y=val_labels)
# aggregate the final mean dice result
metric = self.dice_metric.aggregate().item()
print("METRIC", metric)
# reset the status
self.dice_metric.reset()
if tb_id:
self.writer.add_scalar(tb_id, metric, self.epoch_global)
return metric
def validate(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
# Check abort signal
if abort_signal.triggered:
return make_reply(ReturnCode.TASK_ABORTED)
# get round information
self.log_info(fl_ctx, f"Client identity: {fl_ctx.get_identity_name()}")
# update local model weights with received weights
dxo = from_shareable(shareable)
global_weights = dxo.data
# Before loading weights, tensors might need to be reshaped to support HE for secure aggregation.
local_var_dict = self.model.state_dict()
model_keys = global_weights.keys()
for var_name in local_var_dict:
if var_name in model_keys:
weights = torch.as_tensor(global_weights[var_name], device=self.device)
try:
# update the local dict
local_var_dict[var_name] = torch.as_tensor(torch.reshape(weights, local_var_dict[var_name].shape))
except Exception as e:
raise ValueError("Convert weight from {} failed with error: {}".format(var_name, str(e)))
self.model.load_state_dict(local_var_dict)
validate_type = shareable.get_header(AppConstants.VALIDATE_TYPE)
if validate_type == ValidateType.BEFORE_TRAIN_VALIDATE:
# perform valid before local train
global_acc = self.local_valid(self.valid_loader, abort_signal, tb_id="val_dice_global_model")
if abort_signal.triggered:
return make_reply(ReturnCode.TASK_ABORTED)
self.log_info(fl_ctx, f"val_dice_global_model: {global_acc:.4f}")
return DXO(data_kind=DataKind.METRICS, data={MetaKey.INITIAL_METRICS: global_acc}, meta={}).to_shareable()
elif validate_type == ValidateType.MODEL_VALIDATE:
# perform valid
train_acc = self.local_valid(self.train_loader, abort_signal)
if abort_signal.triggered:
return make_reply(ReturnCode.TASK_ABORTED)
self.log_info(fl_ctx, f"training dice: {train_acc:.4f}")
val_acc = self.local_valid(self.valid_loader, abort_signal)
if abort_signal.triggered:
return make_reply(ReturnCode.TASK_ABORTED)
self.log_info(fl_ctx, f"validation dice: {val_acc:.4f}")
self.log_info(fl_ctx, "Evaluation finished. Returning shareable")
val_results = {"train_dice": train_acc, "val_dice": val_acc}
metric_dxo = DXO(data_kind=DataKind.METRICS, data=val_results)
return metric_dxo.to_shareable()
else:
return make_reply(ReturnCode.VALIDATE_TYPE_UNKNOWN)
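A debugging sketch (not part of the trainer above, but using only APIs the trainer already imports): a small helper method for CIFAR10Learner that logs the moment the abort signal fires, so the client log records exactly when and where the task was cancelled. The abort signal is typically triggered from the server/admin side, so that timestamp can then be compared against any timeout in the script that submits and monitors the job.

# Hypothetical helper; add it to CIFAR10Learner and call it in place of the
# bare `if abort_signal.triggered: return` checks, e.g. in local_train:
#     if self._abort_requested(fl_ctx, abort_signal, "local_train batch loop"):
#         return
def _abort_requested(self, fl_ctx, abort_signal: Signal, where: str) -> bool:
    if abort_signal.triggered:
        # Record when and where the abort arrived, plus the current global epoch.
        self.log_info(fl_ctx, f"Abort signal received in {where} at global epoch {self.epoch_global}")
        return True
    return False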
config_server:
{
"format_version": 2,
"min_clients": 2,
"num_rounds": 40,
"server": {
"heart_beat_timeout": 600
},
"task_data_filters": [],
"task_result_filters": [],
"components": [
{
"id": "persistor",
"name": "PTFileModelPersistor",
"args": {
"model": {
"path": "monai.networks.nets.unet.UNet",
"args": {
"dimensions":2,
"in_channels":1,
"out_channels":33,
"channels":[32, 64, 128, 256,512],
"strides":[2, 2, 2, 2],
"num_res_units":2,
"norm": "batch"
}
}
}
},
{
"id": "shareable_generator",
"name": "FullModelShareableGenerator",
"args": {}
},
{
"id": "aggregator",
"name":"InTimeAccumulateWeightedAggregator",
"args": {
}
},
{
"id": "model_selector",
"path": "custom.selection.IntimeModelSelectionHandler",
"args": {}
},
{
"id": "model_locator",
"name": "PTFileModelLocator",
"args": {
"pt_persistor_id": "persistor"
}
},
{
"id": "formatter",
"path": "custom.pt_formatter.PTFormatter",
"args": {}
},
{
"id": "json_generator",
"path": "custom.validation_json_generator.ValidationJsonGenerator",
"args": {}
}
],
"workflows": [
{
"id": "scatter_and_gather",
"name": "ScatterAndGather",
"args": {
"min_clients" : 2,
"num_rounds" : 40,
"start_round": 0,
"wait_time_after_min_received": 10,
"aggregator_id": "aggregator",
"persistor_id": "persistor",
"shareable_generator_id": "shareable_generator",
"train_task_name": "train",
"train_timeout": 0,
"ignore_result_error": true
}
},
{
"id": "cross_site_model_eval",
"name": "CrossSiteModelEval",
"args": {
"model_locator_id": "model_locator",
"formatter_id": "formatter",
"submit_model_timeout": 600,
"validation_timeout": 6000,
"cleanup_models": true
}
}
]
}
config_client:
{
"format_version": 2,
"executors": [
{
"tasks": [ "train", "submit_model", "validate"],
"executor": {
"id": "Executor",
"path": "nvflare.app_common.executors.learner_executor.LearnerExecutor",
"args": {
"learner_id": "cifar10-learner"
}
}
}
],
"task_result_filters": [
],
"task_data_filters": [
],
"components": [
{
"id": "cifar10-learner",
"path": "pt.learners.cifar10_learner.CIFAR10Learner",
"args": {
"dataset_root": "{DATASET_ROOT}",
"aggregation_epochs": 5,
"lr": 0.0001
}
}
]
}
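For completeness, a small sketch (my own, not part of the original post; the file names below are assumptions) that lists every timeout-like setting in the two configs above. Server-side, the only timeouts present are heart_beat_timeout, submit_model_timeout and validation_timeout, and train_timeout is 0 (which should mean no per-task training timeout), so an abort that always arrives after a fixed wall-clock time more likely comes from the script that submits and monitors the job, as in the earlier comments.

import json

def print_timeouts(path: str) -> None:
    """Print every key containing 'timeout' found anywhere in an NVFlare job config."""
    with open(path) as f:
        cfg = json.load(f)

    def walk(node, prefix=""):
        if isinstance(node, dict):
            for key, value in node.items():
                if "timeout" in key:
                    print(f"{path}: {prefix}{key} = {value}")
                walk(value, prefix + key + ".")
        elif isinstance(node, list):
            for i, value in enumerate(node):
                walk(value, f"{prefix}{i}.")

    walk(cfg)

# Placeholder file names; point these at the server and client configs shown above.
for config_file in ("config_fed_server.json", "config_fed_client.json"):
    print_timeouts(config_file)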
Hey, I am getting random aborts that I can't find the source of. I have a server and two clients.
I am getting them consistently at the 33-minute mark, which makes this a bit hard to debug.
I have included some logs and the server config, since the server seems to be the culprit. Any ideas what might be causing this?
Tail of client 1 logs
Tail of client 2 logs
Tail of server logs
Server config