WwZzz / easyFL

An experimental platform for federated learning.
Apache License 2.0
519 stars 88 forks

Using EfficientNet-b0 breaks qfedavg #57

Open Xiaxuanxuan opened 7 months ago

Xiaxuanxuan commented 7 months ago

Thanks for your federated learning framework! It is very concise and easy to port. I do have one question, though: when I switch the model to EfficientNet-b0 and train on CIFAR-10 with qfedavg, fedfv, or fedprox, the loss never changes from start to finish. This is the model I defined:

```python
from torch import nn
from flgo.utils.fmodule import FModule
from efficientnet_pytorch import EfficientNet

class Model(FModule):
    def __init__(self):
        super(Model, self).__init__()
        pretrained = True
        self.base_model = (
            EfficientNet.from_pretrained("efficientnet-b0")
            if pretrained
            else EfficientNet.from_name("efficientnet-b0")
        )
        # self.base_model = torchvision.models.efficientnet_v2_s(pretrained=pretrained)  # alternative backbone, unused
        nftrs = self.base_model._fc.in_features
        # print("Number of features output by EfficientNet", nftrs)
        self.base_model._fc = nn.Linear(nftrs, 10)

    def forward(self, x):
        # Convolution layers
        x = self.base_model.extract_features(x)
        # Pooling and final linear layer
        feature_x = self.base_model._avg_pooling(x)
        if self.base_model._global_params.include_top:
            x = feature_x.flatten(start_dim=1)
            x = self.base_model._dropout(x)
            x = self.base_model._fc(x)
        return x

def init_local_module(object):
    pass

def init_global_module(object):
    if 'Server' in object.__class__.__name__:
        object.model = Model().to(object.device)
```

I get results like this: [screenshot]

WwZzz commented 7 months ago

Hi, someone raised the same problem in the flgo discussion group. The cause is that qfedavg's code computes the model norm through the norm interface, which by default calls flgo.utils.fmodule._model_dict_norm. Since model.state_dict() usually contains statistic entries (e.g. BN running statistics), the norm returned by this interface is extremely large for models with BN layers. If the update step divides by this norm, the model update gets scaled down to nearly zero, which is exactly the situation you observed. I am pasting my fixed qfedavg code below, and the fix will be merged into flgo later.
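To illustrate the issue, here is a minimal sketch (a toy model, not flgo itself) showing how BN buffers in `state_dict()` inflate the squared norm compared to `parameters()`. After 100 training-mode forward passes, the `num_batches_tracked` buffer alone contributes 100² to the state-dict norm:

```python
import torch
from torch import nn

# Toy model with a BatchNorm layer (for illustration only)
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))

# Run a few training-mode forward passes so BN buffers accumulate
model.train()
for _ in range(100):
    model(torch.randn(4, 3, 16, 16))

# state_dict() includes BN buffers (running_mean, running_var,
# num_batches_tracked), so its squared norm blows up
sq_state = sum((v.float() ** 2).sum().item() for v in model.state_dict().values())
# parameters() covers trainable weights only
sq_param = sum((p ** 2).sum().item() for p in model.parameters())

print(sq_state, sq_param)  # the state_dict norm is far larger
```

Dividing an update by the inflated state-dict norm is what shrinks the update toward zero.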

WwZzz commented 7 months ago

```python
"""This is a non-official implementation of 'Fair Resource Allocation in
Federated Learning' (http://arxiv.org/abs/1905.10497). And this implementation
refers to the official github repository https://github.com/litian96/fair_flearn
"""
import flgo.algorithm.fedbase as fedbase
import flgo.utils.fmodule as fmodule
import copy

class Server(fedbase.BasicServer):
    def initialize(self, *args, **kwargs):
        self.init_algo_para({'q': 1.0})

    def iterate(self):
        self.selected_clients = self.sample()
        res = self.communicate(self.selected_clients)
        self.model = self.model - fmodule._model_sum(res['dk']) / sum(res['hk'])
        return len(self.received_clients) > 0

class Client(fedbase.BasicClient):
    def unpack(self, package):
        model = package['model']
        self.global_model = copy.deepcopy(model)
        return model

    def pack(self, model):
        Fk = self.test(self.global_model, 'train')['loss'] + 1e-8
        L = 1.0 / self.learning_rate
        delta_wk = L * (self.global_model - model)
        dk = (Fk ** self.q) * delta_wk
        # compute the squared norm from parameters only, excluding BN statistics
        norm_dwk = 0.0
        for p in delta_wk.parameters():
            norm_dwk += (p ** 2).sum()
        hk = self.q * (Fk ** (self.q - 1)) * norm_dwk + L * (Fk ** self.q)
        self.global_model = None
        return {'dk': dk, 'hk': hk}
```
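For reference, the quantities in `pack` and `iterate` above correspond to the qFedAvg update rule from the paper, with $F_k$ the client's training loss and $L = 1/\eta$ the inverse learning rate:

```math
\Delta_k = L\,(w^t - w_k),\qquad \bar\Delta_k = F_k^{\,q}\,\Delta_k,\qquad h_k = q\,F_k^{\,q-1}\,\lVert\Delta_k\rVert^2 + L\,F_k^{\,q},\qquad w^{t+1} = w^t - \frac{\sum_k \bar\Delta_k}{\sum_k h_k}
```

The fix only changes how $\lVert\Delta_k\rVert^2$ is computed: over trainable parameters instead of the full state dict.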

WwZzz commented 7 months ago

Replacing every norm computation with one based on model.parameters() fixes the problem. However, since BN inherently conflicts with non-IID data in federated learning, I recommend either using a model without BN, or replacing BN with GN (GroupNorm).
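For the BN-to-GN swap suggested above, a minimal sketch could look like this (`bn_to_gn` is a hypothetical helper, not part of flgo; shown on a toy model). GroupNorm keeps no running statistics, which avoids both the norm inflation and the BN/non-IID conflict:

```python
import torch
from torch import nn

def bn_to_gn(module, num_groups=8):
    # Hypothetical helper: recursively replace BatchNorm2d with GroupNorm.
    # GroupNorm has no running statistics, so state_dict() holds only weights.
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            # fall back to 1 group (LayerNorm-like) if channels don't divide evenly
            g = num_groups if child.num_features % num_groups == 0 else 1
            setattr(module, name, nn.GroupNorm(g, child.num_features))
        else:
            bn_to_gn(child, num_groups)
    return module

# Toy model for illustration
model = bn_to_gn(nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)))
out = model(torch.randn(2, 3, 16, 16))  # forward pass still works
```

Whether GN matches BN's accuracy depends on the backbone; for EfficientNet it may need tuning, but it sidesteps the statistic-buffer problem entirely.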