NVIDIA / NVFlare

NVIDIA Federated Learning Application Runtime Environment
https://nvidia.github.io/NVFlare/
Apache License 2.0

Random aborts #316

Closed Nintorac closed 2 years ago

Nintorac commented 2 years ago

Hey, I am getting random aborts that I can't find the source of. I have a server and two clients.

They happen consistently at the 33-minute mark, which makes them a bit hard to debug.

I have included some logs and the server config, since the server seems to be the culprit. Any ideas what might be causing this?

Tail of client 1 logs

727/727 - 83s - loss: 0.1298 - tp: 313.0000 - fp: 83.0000 - tn: 676119.0000 - fn: 199013.0000 - accuracy: 0.7726 - precision: 0.7904 - recall: 0.0016 - auroc: 0.7065 - aupr: 0.3900 - 83s/epoch - 114ms/step
Epoch 4/10
check status from process listener......
check status from process listener......
2022-03-17 00:04:08,196 - ClientRunner - INFO - [run=3]: ABORT (RUN) command received
2022-03-17 00:04:08,196 - ClientRunner - INFO - [run=3, task_name=train, task_id=c93db1f9-1405-426e-aea5-20f6b4c292ae]: triggered task_abort_signal to stop task 'train'
2022-03-17 00:04:08,196 - ClientRunner - INFO - [run=3, task_name=train, task_id=c93db1f9-1405-426e-aea5-20f6b4c292ae]: fired ABORT_TASK event to abort current task train
2022-03-17 00:04:08,196 - ClientRunner - INFO - [run=3]: ABORT (RUN) requests end run events sequence
2022-03-17 00:04:10,198 - ProcessExecutor - INFO - Client training was terminated.
2022-03-17 00:04:10,502 - ProcessExecutor - INFO - process finished with execution code: -9

Tail of client 2 logs

727/727 - 79s - loss: 0.1298 - tp: 316.0000 - fp: 95.0000 - tn: 676107.0000 - fn: 199010.0000 - accuracy: 0.7726 - precision: 0.7689 - recall: 0.0016 - auroc: 0.7069 - aupr: 0.3901 - 79s/epoch - 109ms/step
Epoch 4/10
check status from process listener......
check status from process listener......
check status from process listener......
2022-03-17 00:04:08,164 - ClientRunner - INFO - [run=3]: ABORT (RUN) command received
2022-03-17 00:04:08,164 - ClientRunner - INFO - [run=3, task_name=train, task_id=a8ae57ef-7387-4760-b81e-1499ce45594c]: triggered task_abort_signal to stop task 'train'
2022-03-17 00:04:08,164 - ClientRunner - INFO - [run=3, task_name=train, task_id=a8ae57ef-7387-4760-b81e-1499ce45594c]: fired ABORT_TASK event to abort current task train
2022-03-17 00:04:08,164 - ClientRunner - INFO - [run=3]: ABORT (RUN) requests end run events sequence
2022-03-17 00:04:10,166 - ProcessExecutor - INFO - Client training was terminated.
2022-03-17 00:04:10,326 - ProcessExecutor - INFO - process finished with execution code: -9

Tail of server logs

2022-03-16 23:59:30,480 - ServerRunner - INFO - [run=3, wf=scatter_and_gather, peer=nvflare-client-1-6bddcdd77c-pw7mk, peer_run=3, task_name=train, task_id=a8ae57ef-7387-4760-b81e-1499ce45594c]: sent task assignment to client
2022-03-16 23:59:30,482 - FederatedServer - INFO - Return task:train to client:nvflare-client-1-6bddcdd77c-pw7mk --- (d36e54cb-c7c0-4210-98c0-35d38d97149d) 
2022-03-17 00:04:10,273 - ServerEngine - INFO - Abort the server app run.
2022-03-17 00:04:10,274 - ServerRunner - INFO - [run=3, wf=scatter_and_gather]: asked to abort - triggered abort_signal to stop the RUN
2022-03-17 00:04:10,394 - SigOptScatterAndGather - INFO - [run=3, wf=scatter_and_gather]: Abort signal received. Exiting at round 2.
2022-03-17 00:04:10,394 - SigOptScatterAndGather - INFO - [run=3, wf=scatter_and_gather]: Abort signal received. Exiting at round 2.
2022-03-17 00:04:10,408 - SigOptScatterAndGather - INFO - [run=3, wf=scatter_and_gather]: task train exit with status TaskCompletionStatus.ABORTED
2022-03-17 00:04:11,362 - sigopt.print - INFO - Run finished, view it on the SigOpt dashboard at https://app.sigopt.com/run/194705
2022-03-17 00:04:11,362 - ServerRunner - INFO - [run=3, wf=scatter_and_gather]: Workflow: scatter_and_gather finalizing ...
2022-03-17 00:04:13,410 - ServerRunner - INFO - [run=3, wf=scatter_and_gather]: ABOUT_TO_END_RUN fired
2022-03-17 00:04:13,411 - ServerRunner - INFO - [run=3, wf=scatter_and_gather]: END_RUN fired
2022-03-17 00:04:13,411 - ServerRunner - INFO - [run=3, wf=scatter_and_gather]: Server runner finished.
2022-03-17 00:04:15,304 - FederatedServer - INFO - Server app stopped.

Server config

{
    "format_version": 2,
    "server": {
        "heart_beat_timeout": 600
    },
    "task_data_filters": [],
    "task_result_filters": [],
    "components": [
        {
            "id": "persistor",
            "path": "sepsis_model.TF2ModelPersistor",
            "args": {
                "save_name": "tf2weights.pickle",
                "model_config": {
                    "learning_rate": 0.01,
                    "dropout": 0.5
                }
            }
        },
        {
            "id": "shareable_generator",
            "path": "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator",
            "args": {}
        },
        {
            "id": "aggregator",
            "path": "nvflare.app_common.aggregators.accumulate_model_aggregator.AccumulateWeightedAggregator",
            "args": {
                "expected_data_kind": "WEIGHTS"
            }
        }
    ],
    "workflows": [
        {
            "id": "scatter_and_gather",
            "path": "sepsis_model.SigOptScatterAndGather",
            "args": {
                "min_clients": 2,
                "num_rounds": 10,
                "start_round": 0,
                "wait_time_after_min_received": 10,
                "aggregator_id": "aggregator",
                "persistor_id": "persistor",
                "shareable_generator_id": "shareable_generator",
                "train_task_name": "train",
                "train_timeout": 0,
                "run_id": null
            }
        }
    ]
}

yanchengnv commented 2 years ago

From the log files, it appears that the clients and the server were proceeding normally until an ABORT command was received.

ABORT could happen in several cases:

From the info in the log files, I suspect the run was forcefully aborted by an ABORT command issued by the user or by a script. If you used a script, can you please paste its content?
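One way to check that reading against the logs is to compare when each side first recorded the abort. The sketch below parses timestamps from lines quoted from the logs above. The helper names are hypothetical, and comparing across machines assumes their clocks are roughly in sync.

```python
from datetime import datetime

LOG_TIME_FORMAT = "%Y-%m-%d %H:%M:%S,%f"


def parse_log_time(line: str) -> datetime:
    """Parse the leading 'YYYY-MM-DD HH:MM:SS,mmm' timestamp of a log line."""
    return datetime.strptime(line[:23], LOG_TIME_FORMAT)


def first_match_time(lines, needle):
    """Return the timestamp of the first line containing `needle`, or None."""
    for line in lines:
        if needle in line:
            return parse_log_time(line)
    return None


# Lines copied from the client 2 and server log tails in this issue.
client_log = [
    "2022-03-17 00:04:08,164 - ClientRunner - INFO - [run=3]: ABORT (RUN) command received",
    "2022-03-17 00:04:10,166 - ProcessExecutor - INFO - Client training was terminated.",
]
server_log = [
    "2022-03-16 23:59:30,482 - FederatedServer - INFO - Return task:train to client:"
    "nvflare-client-1-6bddcdd77c-pw7mk --- (d36e54cb-c7c0-4210-98c0-35d38d97149d) ",
    "2022-03-17 00:04:10,273 - ServerEngine - INFO - Abort the server app run.",
]

client_abort = first_match_time(client_log, "ABORT (RUN) command received")
server_abort = first_match_time(server_log, "Abort the server app run")
last_task = parse_log_time(server_log[0])

# Gap between the client seeing ABORT and the server logging its own abort.
lead = (server_abort - client_abort).total_seconds()
# How long the server had been quiet before the abort appeared.
quiet = (server_abort - last_task).total_seconds()

print(f"client saw ABORT {lead:.3f} s before the server logged its abort")
print(f"server log was quiet for {quiet:.3f} s before the abort")
```

With these lines the client records the ABORT about 2 s before the server logs its own abort, and the server shows no errors in the preceding minutes. That would be consistent with a command issued from outside the workflow, though it does not prove it.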

yanchengnv commented 2 years ago

The fact that you get them consistently at the 33-minute mark suggests that there is probably some kind of timeout in the script you run on your PC. We have run into this problem before. If this is the case, you just need to extend the timeout.

Nintorac commented 2 years ago

Yeah, this seems to be the case. I am using the AdminAPIRunner.run command, which has a default timeout of 2000 seconds (33.33 minutes).

Thanks :)

https://github.com/NVIDIA/NVFlare/blob/f360fba591941989a59d30f7be3477e635433a48/nvflare/fuel/hci/client/fl_admin_api_runner.py#L126
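The arithmetic behind this can be checked with a short sketch. The function names are hypothetical. The run length is only a rough lower bound, built from the server config (`num_rounds` of 10) and the client 1 log (about 83 s per epoch, with "Epoch 4/10" taken to mean 10 epochs per round).

```python
from datetime import datetime, timedelta

DEFAULT_TIMEOUT_SECONDS = 2000  # default reported above for the run command


def format_duration(seconds: int) -> str:
    """Render a number of seconds as 'M min SS s'."""
    minutes, secs = divmod(seconds, 60)
    return f"{minutes} min {secs:02d} s"


def implied_start(abort_time: datetime, timeout_seconds: int) -> datetime:
    """Time at which a timer must have started to expire at abort_time."""
    return abort_time - timedelta(seconds=timeout_seconds)


def estimated_run_seconds(num_rounds: int, epochs_per_round: int, seconds_per_epoch: float) -> float:
    """Lower bound on training time; ignores communication and aggregation."""
    return num_rounds * epochs_per_round * seconds_per_epoch


# Client ABORT timestamp from the logs, truncated to whole seconds.
abort_seen = datetime(2022, 3, 17, 0, 4, 8)
print(format_duration(DEFAULT_TIMEOUT_SECONDS))            # 33 min 20 s
print(implied_start(abort_seen, DEFAULT_TIMEOUT_SECONDS))  # 2022-03-16 23:30:48

# Assumed inputs: 10 rounds (server config), 10 epochs per round and
# 83 s per epoch (both read off the client 1 log tail).
needed = estimated_run_seconds(num_rounds=10, epochs_per_round=10, seconds_per_epoch=83)
print(needed, needed > DEFAULT_TIMEOUT_SECONDS)            # 8300 True
```

If that estimate is roughly right, the full run needs well over four times the default, so the timeout passed to the run call would have to be set comfortably above 8300 seconds. The keyword is presumably named `timeout`, but that should be checked against the linked source for the installed version.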

yanchengnv commented 2 years ago

Glad it's an easy fix. :-)

Nintorac commented 2 years ago

Is it worth changing this default to 0 and making timeouts opt-in?

At the moment, setting the value to 0 causes an instant abort.
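A generic wait loop shows why a timeout of 0 aborts immediately, and how an opt-in timeout could be expressed instead. This is a hypothetical illustration, not NVFlare's implementation: `wait_for_completion`, `FakeClock` and the `None` convention are all assumptions made for the example.

```python
import time
from typing import Callable, Optional


def wait_for_completion(
    is_done: Callable[[], bool],
    timeout: Optional[float],
    poll_interval: float = 1.0,
    clock: Callable[[], float] = time.monotonic,
    sleep: Callable[[float], None] = time.sleep,
) -> bool:
    """Return True if is_done() became true, False if the timeout expired first.

    timeout=None waits indefinitely (opt-in timeouts);
    timeout=0 expires on the very first check.
    """
    start = clock()
    while not is_done():
        if timeout is not None and clock() - start >= timeout:
            return False  # a caller would issue the abort here
        sleep(poll_interval)
    return True


class FakeClock:
    """Deterministic clock so the example runs instantly."""

    def __init__(self) -> None:
        self.now = 0.0

    def time(self) -> float:
        return self.now

    def sleep(self, seconds: float) -> None:
        self.now += seconds


def run_case(timeout: Optional[float], job_seconds: float = 5.0) -> bool:
    """Simulate a job that finishes after job_seconds of fake time."""
    fake = FakeClock()
    return wait_for_completion(
        lambda: fake.now >= job_seconds, timeout, clock=fake.time, sleep=fake.sleep
    )


print(run_case(timeout=0))     # False: elapsed 0 >= timeout 0 on the first check
print(run_case(timeout=3))     # False: the job needs 5 s but only 3 s are allowed
print(run_case(timeout=None))  # True: no limit, the job is allowed to finish
```

Treating `None` rather than 0 as "no timeout" keeps 0 meaningful as "do not wait at all", which is one common way to make timeouts opt-in without changing what 0 means.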

LSnyd commented 2 years ago

Hi all, I have the same problem. My training is aborted every time, seemingly at random, after approximately 36 global epochs. I am using a modified version of the CIFAR-10 trainer, and I can't find a timeout in it. @yanchengnv, do you by chance have an idea where the abort comes from in my case?

Trainer:

# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import os
from monai.metrics import DiceMetric
import numpy as np
import torch
import torch.optim as optim
from monai.data import create_test_image_2d, list_data_collate, decollate_batch
from monai.inferers import sliding_window_inference
from pt.networks.cifar10_nets import ModerateCNN
from pt.utils.cifar10_dataset import CIFAR10_Idx
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

from nvflare.apis.dxo import DXO, DataKind, MetaKey, from_shareable
from nvflare.apis.fl_constant import FLContextKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.learner_spec import Learner
from nvflare.app_common.app_constant import AppConstants, ModelName, ValidateType
from nvflare.app_common.pt.pt_fedproxloss import PTFedProxLoss
import json
import logging
import os

import torch
import torch.distributed as dist
from monai.data import (
    CacheDataset,
    DataLoader,
    load_decathlon_datalist,
    partition_dataset,
)
from monai.engines import SupervisedEvaluator, SupervisedTrainer
from monai.handlers import (
    CheckpointSaver,
    LrScheduleHandler,
    MeanDice,
    StatsHandler,
    TensorBoardStatsHandler,
    ValidationHandler,
)

from monai.inferers import SimpleInferer, SlidingWindowInferer
from monai.losses import DiceLoss, DiceFocalLoss
from monai.networks.layers import Norm
from monai.networks.nets import UNet, BasicUNet, SegResNet
from monai.transforms import (
    Activations,
    AddChanneld,
    AsDiscrete,
    Activationsd,
    AsDiscreted,
    Compose,
    CropForegroundd,
    EnsureChannelFirstd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    ScaleIntensityRanged,
    ScaleIntensityd,
    Spacingd,
    ToTensord,
    RandGaussianNoised,
    RandShiftIntensityd,
    HistogramNormalized,
    Resized
)

from torch.nn.parallel import DistributedDataParallel
from monai.handlers import from_engine

import custom.ImageDataloader as ImageDataloader

from custom.model import get_model

from custom.processing import Move_channel_to_front

class CIFAR10Learner(Learner):
    def __init__(
        self,
        dataset_root: str = "./dataset",
        aggregation_epochs: int = 1,
        train_task_name: str = AppConstants.TASK_TRAIN,
        submit_model_task_name: str = AppConstants.TASK_SUBMIT_MODEL,
        lr: float = 1e-2,
        fedproxloss_mu: float = 0.0,
        central: bool = False,
    ):
        """Simple CIFAR-10 Trainer.

        Args:
            dataset_root: directory with CIFAR-10 data.
            aggregation_epochs: the number of training epochs for a round. Defaults to 1.
            train_task_name: name of the task to train the model.
            submit_model_task_name: name of the task to submit the best local model.

        Returns:
            a Shareable with the updated local model after running `execute()`
            or the best local model depending on the specified task.
        """
        super().__init__()
        # trainer init happens at the very beginning, only the basic info regarding the trainer is set here
        # the actual run has not started at this point

        self.aggregation_epochs = aggregation_epochs
        self.train_task_name = train_task_name
        self.lr = lr
        self.fedproxloss_mu = fedproxloss_mu
        self.submit_model_task_name = submit_model_task_name
        self.best_acc = 0.0
        self.central = central
        self.initialized = False

        # Epoch counter
        self.epoch_of_start_time = 0
        self.epoch_global = 0

    def initialize(self, parts: dict, fl_ctx: FLContext):
        # when the run starts, this is where the actual settings get initialized for trainer

        # Set the paths according to fl_ctx
        # wf_config_file_name=fl_args.train_config,
        # app_root = config_root
        self.dataset_root = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT).get_root_dir() #dataset_root
        self.app_root = fl_ctx.get_prop(FLContextKey.APP_ROOT)
        fl_args = fl_ctx.get_prop(FLContextKey.ARGS)
        self.client_id = fl_ctx.get_identity_name()
        self.log_info(
            fl_ctx,
            f"Client {self.client_id} initialized at \n {self.app_root} \n with args: {fl_args}",
        )

        self.local_model_file = os.path.join(self.app_root, "local_model.pt")
        self.best_local_model_file = os.path.join(self.app_root, "best_local_model.pt")

        # Set local tensorboard writer - to be replaced by event
        self.writer = SummaryWriter(self.app_root)

        # Set datalist, here the path and filename are hard-coded, can also be fed as an argument
        site_idx_file_name = os.path.join(self.app_root, fl_args.train_config)

        with open(os.path.join(self.app_root, fl_args.train_config)) as file:
            wf_config = json.load(file)

        self.wf_config = wf_config
        self.max_epochs = wf_config["max_epochs"]
        self.learning_rate = wf_config["learning_rate"]
        self.data_list_base_dir = self.app_root 
        self.data_list = wf_config["data_list_json_file"]
        self.val_interval = wf_config["val_interval"]
        self.ckpt_dir = wf_config["ckpt_dir"]
        self.amp = wf_config["amp"]
        self.use_gpu = wf_config["use_gpu"]
        self.multi_gpu = wf_config["multi_gpu"]
        self.local_rank = 0
        self.identity_name = fl_ctx.get_identity_name()
        self.data_list_json_file = self.data_list + self.identity_name + ".json"
        site_idx_file_name = os.path.join(self.dataset_root, self.client_id + ".npy")

        self.log_info(fl_ctx, f"IndexList Path: {site_idx_file_name}")

        # set the training-related parameters
        # can be replaced by a config-style block
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.model = UNet(
            dimensions=2,
            in_channels=1,
            out_channels=33,
            channels=(32, 64, 128, 256,512),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH,
        ).to(self.device)

        self.transform_train = Compose(
            [
                LoadImaged(keys=("image", "label")),
                Move_channel_to_front(keys=("image")),
                Resized(keys=("image", "label"),spatial_size=(224,224)),
                AddChanneld(keys=["image", "label"]),
                ScaleIntensityd(
                    keys="image",
                    minv=0.0,
                    maxv=1.0,
                ),
                RandGaussianNoised(keys=("image"), prob=0),
                RandShiftIntensityd(keys=("image"), offsets=0.1, prob=0),
          #      HistogramNormalized(keys=("image"), min=0.0, max=1.0),
                ToTensord(keys=("image", "label")),
            ]
        )

        self.transform_valid = Compose(

            [
                LoadImaged(keys=("image", "label")),
                Move_channel_to_front(keys=("image")),
                Resized(keys=("image", "label"),spatial_size=(224,224)),
                AddChanneld(keys=["image", "label"]),
                ScaleIntensityd(
                    keys="image",
                    minv=0.0,
                    maxv=1.0,
                ),
             #   HistogramNormalized(keys=("image"), min=0.0, max=1.0),
                ToTensord(keys=("image", "label")),
            ]
        )

        # set datalist
        self.train_dataset = load_decathlon_datalist(
            os.path.join(self.data_list_base_dir, self.data_list_json_file),
            is_segmentation=True,
            data_list_key="training",
            base_dir=self.data_list_base_dir,
        )
        self.valid_dataset = load_decathlon_datalist(
            os.path.join(self.data_list_base_dir, self.data_list_json_file),
            is_segmentation=True,
            data_list_key="validation",
            base_dir=self.data_list_base_dir,
        )

        # set datalist
        self.train_loader = ImageDataloader.ImageDataset(
            data=self.train_dataset, 
            transform=self.transform_train
        )

        self.valid_loader = ImageDataloader.ImageDataset(
            data=self.valid_dataset,
            transform=self.transform_valid
        )

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        self.criterion = DiceFocalLoss(include_background=False, to_onehot_y=False, softmax=True)
        if self.fedproxloss_mu > 0:
            self.log_info(fl_ctx, f"using FedProx loss with mu {self.fedproxloss_mu}")
            self.criterion_prox = PTFedProxLoss(mu=self.fedproxloss_mu)

        self.post_transform = Compose(
            [
                Activations(softmax=True),
                AsDiscrete(threshold=0.5)
            ]
        )

        self.dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)

        self.inferer = SimpleInferer()

    def finalize(self, fl_ctx: FLContext):
        # collect threads, close files here
        pass

    def local_train(self, fl_ctx, train_loader, model_global, abort_signal: Signal, val_freq: int = 0):
        for epoch in range(self.aggregation_epochs):
                if abort_signal.triggered:
                    return

                self.model.train()
                epoch_len = len(train_loader)
                self.epoch_global = self.epoch_of_start_time + epoch
                self.log_info(fl_ctx, f"Local epoch {self.client_id}: {epoch + 1}/{self.aggregation_epochs} (lr={self.lr})")
                for i,batch_data in enumerate(train_loader):

                    inputs, labels = batch_data["image"].to(self.device), batch_data["label"].to(self.device)

                    if abort_signal.triggered:
                        return

                    # zero the parameter gradients
                    self.optimizer.zero_grad()
                    # forward + backward + optimize
                    outputs = self.model(inputs)
                    loss = self.criterion(outputs, labels)

                    # FedProx loss term
                    if self.fedproxloss_mu > 0:
                        fed_prox_loss = self.criterion_prox(self.model, model_global)
                        loss += fed_prox_loss

                    loss.backward()
                    self.optimizer.step()
                    self.writer.add_scalar("train_loss", loss.item(), epoch_len * self.epoch_global + i)
                if val_freq > 0 and epoch % val_freq == 0:
                    acc = self.local_valid(self.valid_loader, abort_signal, tb_id="val_acc_local_model")
                    if acc > self.best_acc:
                        self.save_model(is_best=True)

    def save_model(self, is_best=False):
        # save model
        model_weights = self.model.state_dict()
        save_dict = {"model_weights": model_weights, "epoch": self.epoch_global}
        if is_best:
            save_dict.update({"best_acc": self.best_acc})
            torch.save(save_dict, self.best_local_model_file)
        else:
            torch.save(save_dict, self.local_model_file)

    def train(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        # Check abort signal
        if abort_signal.triggered:
            return make_reply(ReturnCode.TASK_ABORTED)

        try:

            # get round information
            current_round = shareable.get_header(AppConstants.CURRENT_ROUND)
            total_rounds = shareable.get_header(AppConstants.NUM_ROUNDS)
            self.log_info(fl_ctx, f"Current/Total Round: {current_round + 1}/{total_rounds}")
            self.log_info(fl_ctx, f"Client identity: {fl_ctx.get_identity_name()}")

            # update local model weights with received weights
            dxo = from_shareable(shareable)
            global_weights = dxo.data

            # Before loading weights, tensors might need to be reshaped to support HE for secure aggregation.
            local_var_dict = self.model.state_dict()
            model_keys = global_weights.keys()
            for var_name in local_var_dict:
                if var_name in model_keys:
                    weights = global_weights[var_name]
                    try:
                        # reshape global weights to compute difference later on
                        global_weights[var_name] = np.reshape(weights, local_var_dict[var_name].shape)
                        # update the local dict
                        local_var_dict[var_name] = torch.as_tensor(global_weights[var_name])
                    except Exception as e:
                        raise ValueError("Convert weight from {} failed with error: {}".format(var_name, str(e)))
            self.model.load_state_dict(local_var_dict)

            # local steps
            epoch_len = len(self.train_loader)
            self.log_info(fl_ctx, f"Local steps per epoch: {epoch_len}")

            # make a copy of model_global as reference for potential FedProx loss
            if self.fedproxloss_mu > 0:
                model_global = copy.deepcopy(self.model)
                for param in model_global.parameters():
                    param.requires_grad = False
            else:
                model_global = None

            # local train
            self.local_train(
                fl_ctx=fl_ctx,
                train_loader=self.train_loader,
                model_global=model_global,
                abort_signal=abort_signal,
                val_freq=1 if self.central else 0,
            )
            if abort_signal.triggered:
                return make_reply(ReturnCode.TASK_ABORTED)
            self.epoch_of_start_time += self.aggregation_epochs

            # perform valid after local train
            acc = self.local_valid(self.valid_loader, abort_signal, tb_id="val_acc_local_model")
            if abort_signal.triggered:
                return make_reply(ReturnCode.TASK_ABORTED)
            self.log_info(fl_ctx, f"val_acc_local_model: {acc:.4f}")

            # save model
            self.save_model(is_best=False)
            if acc > self.best_acc:
                self.save_model(is_best=True)

            # compute delta model, global model has the primary key set
            local_weights = self.model.state_dict()

            model_diff = {}
            for name in global_weights:
                if name not in local_weights:
                    continue
                model_diff[name] = local_weights[name].cpu().numpy() - global_weights[name]
                if np.any(np.isnan(model_diff[name])):
                    self.system_panic(f"{name} weights became NaN...", fl_ctx)
                    return make_reply(ReturnCode.EXECUTION_EXCEPTION)

            # build the shareable
            dxo = DXO(data_kind=DataKind.WEIGHT_DIFF, data=model_diff)
            dxo.set_meta_prop(MetaKey.NUM_STEPS_CURRENT_ROUND, epoch_len)

            self.log_info(fl_ctx, "Local epochs finished. Returning shareable")

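            return dxo.to_shareable()
        except Exception as e:
            # Assumed completion: the `try` opened at the top of this method had no handler.
            # Log and re-raise so the calling executor still sees the error.
            self.log_error(fl_ctx, f"Training failed: {e}")
            raise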

    def get_model_for_validation(self, model_name: str, fl_ctx: FLContext) -> Shareable:
        # Retrieve the best local model saved during training.
        if model_name == ModelName.BEST_MODEL:
            model_data = None
            try:
                # load model to cpu as server might or might not have a GPU
                model_data = torch.load(self.best_local_model_file, map_location="cpu")
            except Exception as e:
                self.log_error(fl_ctx, f"Unable to load best model: {e}")

            # Create DXO and shareable from model data.
            if model_data:
                dxo = DXO(data_kind=DataKind.WEIGHTS, data=model_data["model_weights"])
                return dxo.to_shareable()
            else:
                # Set return code.
                self.log_error(fl_ctx, f"best local model not found at {self.best_local_model_file}.")
                return make_reply(ReturnCode.EXECUTION_RESULT_ERROR)
        else:
            raise ValueError(f"Unknown model_type: {model_name}")  # Raised errors are caught in LearnerExecutor class.

    def local_valid(self, valid_loader, abort_signal: Signal, tb_id=None):

        self.model.eval()
        with torch.no_grad():
            for val_data in valid_loader:
                val_images, val_labels = val_data["image"].to(self.device), val_data["label"].to(self.device)

                # define sliding window size and batch size for windows inference
                roi_size = (96, 96)
                sw_batch_size = 4
           #     val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, self.model)
                val_outputs = self.inferer(val_images, self.model)

                val_outputs = [self.post_transform(i) for i in decollate_batch(val_outputs)]
                val_labels = decollate_batch(val_labels)
                # compute metric for current iteration
                self.dice_metric(y_pred=val_outputs, y=val_labels)

            # aggregate the final mean dice result
            metric = self.dice_metric.aggregate().item()
            print("METRIC", metric)
            # reset the status
            self.dice_metric.reset()

            if tb_id:
                self.writer.add_scalar(tb_id, metric, self.epoch_global)
        return metric

    def validate(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        # Check abort signal
        if abort_signal.triggered:
            return make_reply(ReturnCode.TASK_ABORTED)

        # get round information
        self.log_info(fl_ctx, f"Client identity: {fl_ctx.get_identity_name()}")

        # update local model weights with received weights
        dxo = from_shareable(shareable)
        global_weights = dxo.data

        # Before loading weights, tensors might need to be reshaped to support HE for secure aggregation.
        local_var_dict = self.model.state_dict()
        model_keys = global_weights.keys()
        for var_name in local_var_dict:
            if var_name in model_keys:
                weights = torch.as_tensor(global_weights[var_name], device=self.device)
                try:
                    # update the local dict
                    local_var_dict[var_name] = torch.as_tensor(torch.reshape(weights, local_var_dict[var_name].shape))
                except Exception as e:
                    raise ValueError("Convert weight from {} failed with error: {}".format(var_name, str(e)))
        self.model.load_state_dict(local_var_dict)

        validate_type = shareable.get_header(AppConstants.VALIDATE_TYPE)
        if validate_type == ValidateType.BEFORE_TRAIN_VALIDATE:
            # perform valid before local train
            global_acc = self.local_valid(self.valid_loader, abort_signal, tb_id="val_dice_global_model")
            if abort_signal.triggered:
                return make_reply(ReturnCode.TASK_ABORTED)
            self.log_info(fl_ctx, f"val_dice_global_model: {global_acc:.4f}")

            return DXO(data_kind=DataKind.METRICS, data={MetaKey.INITIAL_METRICS: global_acc}, meta={}).to_shareable()

        elif validate_type == ValidateType.MODEL_VALIDATE:
            # perform valid
            train_acc = self.local_valid(self.train_loader, abort_signal)
            if abort_signal.triggered:
                return make_reply(ReturnCode.TASK_ABORTED)
            self.log_info(fl_ctx, f"training dice: {train_acc:.4f}")

            val_acc = self.local_valid(self.valid_loader, abort_signal)
            if abort_signal.triggered:
                return make_reply(ReturnCode.TASK_ABORTED)
            self.log_info(fl_ctx, f"validation dice: {val_acc:.4f}")

            self.log_info(fl_ctx, "Evaluation finished. Returning shareable")

            val_results = {"train_dice": train_acc, "val_dice": val_acc}

            metric_dxo = DXO(data_kind=DataKind.METRICS, data=val_results)
            return metric_dxo.to_shareable()

        else:
            return make_reply(ReturnCode.VALIDATE_TYPE_UNKNOWN)

config_server:

{
  "format_version": 2,
  "min_clients": 2,
  "num_rounds": 40,

  "server": {
    "heart_beat_timeout": 600
  },
  "task_data_filters": [],
  "task_result_filters": [],
  "components": [
    {
      "id": "persistor",
      "name": "PTFileModelPersistor",
      "args": {
        "model": {
          "path": "monai.networks.nets.unet.UNet",
          "args": {
            "dimensions": 2,
            "in_channels": 1,
            "out_channels": 33,
            "channels": [32, 64, 128, 256, 512],
            "strides": [2, 2, 2, 2],
            "num_res_units": 2,
            "norm": "batch"
          }
        }
      }
    },
    {
      "id": "shareable_generator",
      "name": "FullModelShareableGenerator",
      "args": {}
    },
    {
      "id": "aggregator",
      "name": "InTimeAccumulateWeightedAggregator",
      "args": {}
    },
    {
      "id": "model_selector",
      "path": "custom.selection.IntimeModelSelectionHandler",
      "args": {}
    },
    {
      "id": "model_locator",
      "name": "PTFileModelLocator",
      "args": {
        "pt_persistor_id": "persistor"
      }
    },
    {
      "id": "formatter",
      "path": "custom.pt_formatter.PTFormatter",
      "args": {}
    },
    {
      "id": "json_generator",
      "path": "custom.validation_json_generator.ValidationJsonGenerator",
      "args": {}
    }
  ],
  "workflows": [
    {
      "id": "scatter_and_gather",
      "name": "ScatterAndGather",
      "args": {
        "min_clients": 2,
        "num_rounds": 40,
        "start_round": 0,
        "wait_time_after_min_received": 10,
        "aggregator_id": "aggregator",
        "persistor_id": "persistor",
        "shareable_generator_id": "shareable_generator",
        "train_task_name": "train",
        "train_timeout": 0,
        "ignore_result_error": true
      }
    },
    {
      "id": "cross_site_model_eval",
      "name": "CrossSiteModelEval",
      "args": {
        "model_locator_id": "model_locator",
        "formatter_id": "formatter",
        "submit_model_timeout": 600,
        "validation_timeout": 6000,
        "cleanup_models": true
      }
    }
  ]
}

config_client:


{
  "format_version": 2,
  "executors": [
    {
      "tasks": [ "train", "submit_model", "validate"],
      "executor": {
        "id": "Executor",
        "path": "nvflare.app_common.executors.learner_executor.LearnerExecutor",
        "args": {
          "learner_id": "cifar10-learner"
        }
      }
    }
  ],
  "task_result_filters": [
  ],
  "task_data_filters": [

  ],
  "components": [
    {
      "id": "cifar10-learner",
      "path": "pt.learners.cifar10_learner.CIFAR10Learner",
      "args": {
        "dataset_root": "{DATASET_ROOT}",
        "aggregation_epochs": 5,
        "lr": 0.0001
      }
    }
  ]
}