Runist / torch_maml

Very simple pytorch maml implement
108 stars 13 forks source link

关于二阶导必须把模型参数摘出来? #6

Open huiguhean opened 1 month ago

huiguhean commented 1 month ago

maml实现中,必须把模型参数摘出来,重写forward,吗?对于复杂一点的模型太麻烦了,因为在qry阶段模型已经经过内循环更新,得到loss也只能对更新过的model用参数更新,没办法直接对初始模型init_model参数更新。 您采用的方法是把模型参数摘出来,并在模型中用了function_forward代替模型本身的forward进行损失计算和摘出来的参数更新,用function_forward代替模型本身forward,从而实现的最后的对模型初始参数的更新。 对于复杂模型,有无简单方法?

A-cloud-bit commented 1 month ago

是的,我看了几套模型的代码,大都是这样实现的。虽然pytorch提供了元学习的类,但是改起来也太麻烦了。。我尝试手动改了一下,但是还是有问题哥们可以一块交流一下 def maml_train(model, support_images, support_labels, query_images, query_labels, inner_step, args, optimizer, is_train=True): """ Train the model using MAML method. Args: model: Any model support_images: several task support images support_labels: several support labels query_images: several query images query_labels: several query labels inner_step: support data training step args: ArgumentParser optimizer: optimizer is_train: whether train

Returns: meta loss, meta accuracy

"""
meta_loss = []
meta_acc = []
# support_images,support_labels,所有任务的数据
# 遍历每个任务的数据
for support_image, support_label, query_image, query_label in zip(support_images, support_labels, query_images, query_labels):

   #第二种写法,自己写的,但是不知道是不是正确的
    inner_model = copy.deepcopy(model)
    for _ in range(inner_step):
        support_logit = inner_model(support_image)
        support_loss = nn.CrossEntropyLoss().cuda()(support_logit, support_label)
        grads = torch.autograd.grad(support_loss, inner_model.parameters(), create_graph=True)
        # 这个时候inner_model没有梯度
        for param, grad in zip(inner_model.parameters(), grads):
            param.data -= args.inner_lr * grad
        # 这个时候inner_model没有梯度
   #query_logit = model.functional_forward(query_image, fast_weights)
    query_logit = inner_model(query_image)
    query_prediction = torch.max(query_logit, dim=1)[1]
    query_loss = nn.CrossEntropyLoss().cuda()(query_logit, query_label)
    query_acc = torch.eq(query_label, query_prediction).sum() / len(query_label)
    meta_loss.append(query_loss)
    meta_acc.append(query_acc.data.cpu().numpy())

# 外层模型梯度清0
optimizer.zero_grad()
meta_loss = torch.stack(meta_loss).mean()
meta_acc = np.mean(meta_acc)

if is_train:
    # 手动清零梯度,清除内层模型
    for param in inner_model.parameters():
        param.grad = None  # 或者使用 param.grad.zero_()
    meta_loss.backward() # 其实这里内层模型存储了梯度,但是我需要把这个模型的梯度清0
    for inner_param, outer_param in zip(inner_model.parameters(), model.parameters()):
        outer_param.grad = inner_param.grad  # 将内层模型的梯度赋值给外层模型的参数
    # 这个优化器的定义是关于外层模型的,
    optimizer.step()

return meta_loss, meta_acc
A-cloud-bit commented 1 month ago

加个q交流一下2832485959

Runist commented 1 month ago

这个方面,就没有办法继续帮你了,我也只是照着作者的思路实现了一个最简单的感知机。