OpenMined / KotlinSyft

The official Syft worker for secure on-device machine learning
https://www.openmined.org
Apache License 2.0

The README.md is too old, could anyone update it? #323

Open xiechunHonor opened 3 years ago

xiechunHonor commented 3 years ago

Could anyone provide a new example of federated learning on mobile, along the lines of "PySyft/packages/syft/examples/federated-learning/model-centric/mcfl_create_plan_mobile.ipynb"? Thanks!
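
Until that notebook is updated, the core of the model-centric flow can be sketched in plain PyTorch: the server sends the global model, each client trains locally, and the server averages the returned weights. This is only the averaging idea, not the PySyft/PyGrid API; local_update, the layer sizes, and the fake client data are all illustrative.

    import copy
    import torch

    def local_update(model, data, target, lr=0.1):
        # one local SGD step on a client's private data
        out = model(data)
        loss = torch.nn.functional.cross_entropy(out, target)
        loss.backward()
        with torch.no_grad():
            for p in model.parameters():
                p -= lr * p.grad
                p.grad = None
        return model.state_dict()

    global_model = torch.nn.Linear(784, 10)

    # fake data standing in for three devices
    clients = [(torch.randn(32, 784), torch.randint(0, 10, (32,))) for _ in range(3)]

    local_states = [
        local_update(copy.deepcopy(global_model), data, target)
        for data, target in clients
    ]

    # federated averaging of the returned weights
    avg_state = {
        k: torch.stack([s[k] for s in local_states]).mean(0)
        for k in global_model.state_dict()
    }
    global_model.load_state_dict(avg_state)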

LuckyHFC commented 3 years ago

Hello, I also need to build a custom network model and then train it on my own phone. After some experimenting the program does run, but the model never converges. I hope someone can write a new example implementing backpropagation for a more complex network (ideally backpropagation for a convolutional network). Thanks! [screenshot attached: QQ图片20210601094408]


import syft as sy


class MLP(sy.Module):
    def __init__(self, torch_ref) -> None:
        super(MLP, self).__init__(torch_ref=torch_ref)
        self.fc1 = torch_ref.nn.Linear(784, 256)
        self.relu = torch_ref.nn.ReLU()
        self.fc2 = torch_ref.nn.Linear(256, 128)
        self.fc3 = torch_ref.nn.Linear(128, 64)
        self.fc4 = torch_ref.nn.Linear(64, 10)

    def forward(self, x):
        # cache activations; the manual backward pass below needs them
        self.z1 = self.fc1(x)
        self.a1 = self.relu(self.z1)
        self.z2 = self.fc2(self.a1)
        self.a2 = self.relu(self.z2)
        self.z3 = self.fc3(self.a2)
        self.a3 = self.relu(self.z3)
        return self.fc4(self.a3)

    def backward(self, X, error):
        # error is dL/d(logits), shape [batch, 10]; l1..l3 count backwards
        # from the output, so l1 -> fc3, l2 -> fc2, l3 -> fc1, l4 -> fc4.
        # Each weight gradient must use the activation that fed that layer;
        # the originally posted version paired them with the wrong ones.
        # gradient at z3, then fc3's gradients (fc3's input is a2)
        l1_grad = (error @ self.fc4.state_dict()["weight"]) * (self.a3 > 0).float()
        l1_weight = l1_grad.t() @ self.a2
        l1_bias = l1_grad.sum(0)
        # gradient at z2, then fc2's gradients (fc2's input is a1)
        l2_grad = (l1_grad @ self.fc3.state_dict()["weight"]) * (self.a2 > 0).float()
        l2_weight = l2_grad.t() @ self.a1
        l2_bias = l2_grad.sum(0)
        # gradient at z1, then fc1's gradients (fc1's input is X)
        l3_grad = (l2_grad @ self.fc2.state_dict()["weight"]) * (self.a1 > 0).float()
        l3_weight = l3_grad.t() @ X
        l3_bias = l3_grad.sum(0)
        # fc4's gradients (fc4's input is a3, not a1)
        l4_weight = error.t() @ self.a3
        l4_bias = error.sum(0)
        return l1_weight, l1_bias, l2_weight, l2_bias, l3_weight, l3_bias, l4_weight, l4_bias

    def softmax_cross_entropy_with_logits(self, logits, target, batch_size):
        # target is expected to be one-hot encoded, shape [batch, 10]
        probs = self.torch_ref.softmax(logits, dim=1)
        loss = -(target * self.torch_ref.log(probs)).sum(dim=1).mean()
        # gradient of the mean cross-entropy w.r.t. the logits;
        # feed this to backward() as `error`
        loss_grad = (probs - target) / batch_size
        return loss, loss_grad

    def accuracy(self, logits, targets, batch_size):
        pred = self.torch_ref.argmax(logits, dim=1)
        targets_idx = self.torch_ref.argmax(targets, dim=1)
        acc = pred.eq(targets_idx).sum().float() / batch_size
        return acc
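
On the non-convergence reported above: the posted backward() paired each weight gradient with the wrong cached activation (e.g. l1_grad.t() @ X instead of l1_grad.t() @ a2), which the corrected code above fixes. A quick way to catch this class of bug is to compare the manual gradients against torch.autograd. Below is a minimal, self-contained sketch in plain PyTorch (no syft); the layer sizes are copied from the MLP above, everything else is illustrative.

    import torch

    torch.manual_seed(0)

    # plain-PyTorch stand-ins for the layers of the MLP above
    fc1 = torch.nn.Linear(784, 256)
    fc2 = torch.nn.Linear(256, 128)
    fc3 = torch.nn.Linear(128, 64)
    fc4 = torch.nn.Linear(64, 10)

    batch_size = 32
    X = torch.randn(batch_size, 784)
    target = torch.nn.functional.one_hot(
        torch.randint(0, 10, (batch_size,)), 10
    ).float()

    # forward pass, caching activations exactly as forward() above does
    a1 = torch.relu(fc1(X))
    a2 = torch.relu(fc2(a1))
    a3 = torch.relu(fc3(a2))
    logits = fc4(a3)

    probs = torch.softmax(logits, dim=1)
    loss = -(target * torch.log(probs)).sum(dim=1).mean()
    loss.backward()  # autograd's reference gradients

    # manual gradients, mirroring the corrected backward() above
    error = (probs - target).detach() / batch_size
    g3 = (error @ fc4.weight) * (a3 > 0).float()  # dL/dz3
    g2 = (g3 @ fc3.weight) * (a2 > 0).float()     # dL/dz2
    g1 = (g2 @ fc2.weight) * (a1 > 0).float()     # dL/dz1

    manual = {
        "fc4.weight": error.t() @ a3, "fc4.bias": error.sum(0),
        "fc3.weight": g3.t() @ a2, "fc3.bias": g3.sum(0),
        "fc2.weight": g2.t() @ a1, "fc2.bias": g2.sum(0),
        "fc1.weight": g1.t() @ X, "fc1.bias": g1.sum(0),
    }
    reference = {
        "fc4.weight": fc4.weight.grad, "fc4.bias": fc4.bias.grad,
        "fc3.weight": fc3.weight.grad, "fc3.bias": fc3.bias.grad,
        "fc2.weight": fc2.weight.grad, "fc2.bias": fc2.bias.grad,
        "fc1.weight": fc1.weight.grad, "fc1.bias": fc1.bias.grad,
    }
    for name, value in manual.items():
        match = torch.allclose(value.detach(), reference[name], atol=1e-5)
        print(name, "match" if match else "MISMATCH")

If every entry prints "match", the manual pass agrees with autograd and the training loop itself is the next place to look.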
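
On the request for a convolutional example: rather than deriving conv backprop by hand, one option is torch.nn.grad, which exposes conv2d's input and weight gradients directly. A sketch, assuming a single conv layer on an MNIST-shaped input; the layer and shapes are illustrative, not anything from KotlinSyft.

    import torch
    from torch.nn import grad

    conv = torch.nn.Conv2d(1, 8, kernel_size=3, padding=1)
    x = torch.randn(16, 1, 28, 28)

    out = conv(x)                     # [16, 8, 28, 28]
    grad_out = torch.randn_like(out)  # stand-in for the gradient flowing back

    # weight/input gradients without writing the convolution algebra by hand
    grad_w = grad.conv2d_weight(x, conv.weight.shape, grad_out, stride=1, padding=1)
    grad_b = grad_out.sum(dim=(0, 2, 3))
    grad_x = grad.conv2d_input(x.shape, conv.weight, grad_out, stride=1, padding=1)

    # sanity check against autograd
    out2 = conv(x.requires_grad_(True))
    out2.backward(grad_out)
    print(torch.allclose(grad_w, conv.weight.grad, atol=1e-4))
    print(torch.allclose(grad_b, conv.bias.grad, atol=1e-5))
    print(torch.allclose(grad_x, x.grad, atol=1e-4))

The same allclose pattern scales to deeper stacks: verify one layer at a time against autograd before moving the gradients into a training plan.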