FedCP raises an error when training on IoT datasets #177

============= Running time: 0th =============
Creating server and clients ...
  (conv1): Sequential(
    (0): Conv2d(9, 32, kernel_size=(1, 9), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=(1, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(1, 9), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=(1, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc): Sequential(
    (0): Linear(in_features=3712, out_features=1024, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1024, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=12, bias=True)

Join ratio / total clients: 1.0 / 9
Finished creating server and clients.

-------------Round number: 0-------------
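When using the HARCNN model, the line in_dim = list(args.model.base.parameters())[-1].shape[0] in clientcp.py no longer applies: the last base parameter here appears to be the bias of the second convolution, so its shape[0] is the channel count rather than the flattened feature size. It has to be set manually instead, e.g. in_dim = 3712 for the model printed above (not in_dim = 1664).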

Alternatively, you can use the approach from GPFL's clientgpfl.py: self.feature_dim = list(self.model.head.parameters())[0].shape[1]
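
The difference between the two expressions can be checked on a stand-in model. The sketch below is only an illustration: the layer sizes are copied from the log above, but the base/head split and the variable names (`base`, `head`, `dim_from_base`, `dim_from_head`) are assumptions, not the actual HARCNN class.

```python
import torch.nn as nn

# Stand-in for the printed HARCNN. Layer sizes come from the log;
# the split into `base` and `head` is an assumption for illustration.
base = nn.Sequential(
    nn.Sequential(
        nn.Conv2d(9, 32, kernel_size=(1, 9)),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=(1, 2), stride=2),
    ),
    nn.Sequential(
        nn.Conv2d(32, 64, kernel_size=(1, 9)),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=(1, 2), stride=2),
    ),
    nn.Flatten(),
)
head = nn.Sequential(
    nn.Linear(3712, 1024),
    nn.ReLU(),
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Linear(512, 12),
)

# The last base parameter is the bias of the second Conv2d, shape (64,).
dim_from_base = list(base.parameters())[-1].shape[0]
# The first head parameter is the first Linear weight, shape (1024, 3712).
dim_from_head = list(head.parameters())[0].shape[1]

print(dim_from_base)  # 64
print(dim_from_head)  # 3712
```

Reading the dimension from the head avoids hard-coding a value that changes with the input window length.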

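In general, in_dim is not the only thing that has to change: the other code that handles the head needs updating too. For example, in set_head_g, headw_p can no longer be obtained directly with head.weight.data.clone(), because the head in HARCNN consists of several FC layers and therefore has several weight matrices. Instead, use matmul to multiply the weights of all FC layers in the head in order, collapsing them into a single weight matrix that serves as the basis for generating the context.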


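    def set_head_g(self, head):
        # Collect the weight matrix of every FC layer in the head, in forward order.
        headw_ps = []
        for name, mat in self.model.model.head.named_parameters():
            if 'weight' in name:
                headw_ps.append(mat.data)
        # Multiply from the last layer back to the first: W_n @ ... @ W_1.
        headw_p = headw_ps[-1]
        for mat in headw_ps[-2::-1]:
            headw_p = torch.matmul(headw_p, mat)
        # Sum over the output dimension to get a (1, in_dim) context vector.
        self.context = torch.sum(headw_p, dim=0, keepdim=True)

        # Copy the received head parameters into head_g.
        for new_param, old_param in zip(head.parameters(), self.model.head_g.parameters()):
            old_param.data = new_param.data.clone()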
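
The weight-collapsing step can also be tried in isolation to confirm the resulting shapes. This is a minimal sketch, assuming a head with the FC sizes from the log above; `collapse_head_weights` is a hypothetical helper name, not a function from the repository. Note that the product ignores the biases and the ReLU layers, so it is only a linear summary of the head.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the HARCNN head, using the FC sizes from the log.
head = nn.Sequential(
    nn.Linear(3712, 1024),
    nn.ReLU(),
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Linear(512, 12),
)


def collapse_head_weights(head: nn.Module) -> torch.Tensor:
    """Multiply all FC weights as W_n @ ... @ W_1 (biases and ReLU ignored)."""
    weights = [p.data for name, p in head.named_parameters() if 'weight' in name]
    collapsed = weights[-1]
    # Walk from the second-to-last weight back to the first.
    for w in weights[-2::-1]:
        collapsed = torch.matmul(collapsed, w)
    return collapsed


headw_p = collapse_head_weights(head)
context = torch.sum(headw_p, dim=0, keepdim=True)

print(headw_p.shape)  # torch.Size([12, 3712])
print(context.shape)  # torch.Size([1, 3712])
```

The context ends up with one entry per input feature, which is why its width has to agree with in_dim.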


