zhyczy / FedTP

This is the official implementation of FedTP (TNNLS.2023) [https://ieeexplore.ieee.org/abstract/document/10130784].
MIT License

Need your solution #6

Open parmahadi opened 10 months ago

parmahadi commented 10 months ago

Hi, I tried modifying the code to collect data from each client in every global round. The code looks like this:

    client_data_per_round = []

    for round in range(args.comm_round):
        logger.info("in comm round: %d" %round)

        hnet.train()
        grads_update = []

        arr = np.arange(args.n_parties)
        np.random.shuffle(arr)
        selected = arr[:int(args.n_parties * args.sample)]
        weights = hnet(torch.tensor(np.array([selected]), dtype=torch.long).to(device),False)

        global_para = global_model.state_dict()
        if round == 0:
            if args.is_same_initial:
                for ix in range(len(selected)):
                    node_weights = weights[ix]
                    idx = selected[ix]
                    nets[idx].load_state_dict(global_para)
                    nets[idx].load_state_dict(node_weights, strict=False)
        else:
            for ix in range(len(selected)):
                node_weights = weights[ix]
                idx = selected[ix]
                nets[idx].load_state_dict(global_para)
                nets[idx].load_state_dict(node_weights, strict=False)

        if args.dataset == 'shakespeare':
            client_data = local_train_net_per(nets, selected, args, train_dl_global, test_dl_global, logger, device=device)
            client_data_per_round.append(client_data)

        else:
            client_data = local_train_net_per(nets, selected, args, net_dataidx_map_train, net_dataidx_map_test, logger, device=device)
            client_data_per_round.append(client_data)

But when the simulation runs, the output looks like this:

    C:\Users\parma\anaconda3\envs\vehicle\Lib\site-packages\numpy\core\fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.
      return _methods._mean(a, axis=axis, dtype=dtype,
    C:\Users\parma\anaconda3\envs\vehicle\Lib\site-packages\numpy\core\_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide
      ret = ret.dtype.type(ret / rcount)
    C:\Users\parma\anaconda3\envs\vehicle\Lib\site-packages\numpy\core\_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice
      ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
    C:\Users\parma\anaconda3\envs\vehicle\Lib\site-packages\numpy\core\_methods.py:163: RuntimeWarning: invalid value encountered in divide
      arrmean = um.true_divide(arrmean, div, out=arrmean,
    C:\Users\parma\anaconda3\envs\vehicle\Lib\site-packages\numpy\core\_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide
      ret = ret.dtype.type(ret / rcount)
    INFO:root:Test Acc = nan% +- nan%

Could you suggest a solution? Thank you.

Oussamab21 commented 9 months ago

Hi, I am facing the same problem. Are there any solutions? Thanks.
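
For reference, the warnings in the log ("Mean of empty slice", "Degrees of freedom <= 0 for slice") are what NumPy emits when `np.mean` and `np.std` are called on an empty array, and both calls then return `nan`. This suggests, though it is not confirmed against the FedTP source, that the list of per-client test accuracies is empty by the time `Test Acc` is logged. The sketch below reproduces the warnings in isolation and shows one way to guard against them. `summarize_accuracies` is a hypothetical helper, not part of FedTP.

    import warnings
    import numpy as np

    def summarize_accuracies(accs):
        """Hypothetical helper: return (mean, std), or None if the list is empty."""
        accs = np.asarray(accs, dtype=float)
        if accs.size == 0:
            # Nothing was collected, so report that instead of computing nan.
            return None
        return float(accs.mean()), float(accs.std())

    # Reproduce the warnings from the log with an empty array.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        empty_mean = np.mean(np.array([]))  # nan, "Mean of empty slice"
        empty_std = np.std(np.array([]))    # nan, "Degrees of freedom <= 0 for slice"

    messages = [str(w.message) for w in caught]

If the guard returns `None` in the modified training loop, one thing worth checking is whether the change to how the return value of `local_train_net_per` is handled left the accuracy list unpopulated.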