Open huiguhean opened 1 month ago
是的,我看了几套模型的代码,大都是这样实现的。虽然pytorch提供了元学习的类,但是改起来也太麻烦了。。我尝试手动改了一下,但是还是有问题哥们可以一块交流一下 def maml_train(model, support_images, support_labels, query_images, query_labels, inner_step, args, optimizer, is_train=True): """ Train the model using MAML method. Args: model: Any model support_images: several task support images support_labels: several support labels query_images: several query images query_labels: several query labels inner_step: support data training step args: ArgumentParser optimizer: optimizer is_train: whether train
Returns: meta loss, meta accuracy
"""
meta_loss = []
meta_acc = []
# support_images,support_labels,所有任务的数据
# 遍历每个任务的数据
for support_image, support_label, query_image, query_label in zip(support_images, support_labels, query_images, query_labels):
#第二种写法,自己写的,但是不知道是不是正确的
inner_model = copy.deepcopy(model)
for _ in range(inner_step):
support_logit = inner_model(support_image)
support_loss = nn.CrossEntropyLoss().cuda()(support_logit, support_label)
grads = torch.autograd.grad(support_loss, inner_model.parameters(), create_graph=True)
# 这个时候inner_model没有梯度
for param, grad in zip(inner_model.parameters(), grads):
param.data -= args.inner_lr * grad
# 这个时候inner_model没有梯度
#query_logit = model.functional_forward(query_image, fast_weights)
query_logit = inner_model(query_image)
query_prediction = torch.max(query_logit, dim=1)[1]
query_loss = nn.CrossEntropyLoss().cuda()(query_logit, query_label)
query_acc = torch.eq(query_label, query_prediction).sum() / len(query_label)
meta_loss.append(query_loss)
meta_acc.append(query_acc.data.cpu().numpy())
# 外层模型梯度清0
optimizer.zero_grad()
meta_loss = torch.stack(meta_loss).mean()
meta_acc = np.mean(meta_acc)
if is_train:
# 手动清零梯度,清除内层模型
for param in inner_model.parameters():
param.grad = None # 或者使用 param.grad.zero_()
meta_loss.backward() # 其实这里内层模型存储了梯度,但是我需要把这个模型的梯度清0
for inner_param, outer_param in zip(inner_model.parameters(), model.parameters()):
outer_param.grad = inner_param.grad # 将内层模型的梯度赋值给外层模型的参数
# 这个优化器的定义是关于外层模型的,
optimizer.step()
return meta_loss, meta_acc
加个q交流一下2832485959
这个方面,就没有办法继续帮你了,我也只是照着作者的思路实现了一个最简单的感知机。
maml实现中,必须把模型参数摘出来,重写forward,吗?对于复杂一点的模型太麻烦了,因为在qry阶段模型已经经过内循环更新,得到loss也只能对更新过的model用参数更新,没办法直接对初始模型init_model参数更新。 您采用的方法是把模型参数摘出来,并在模型中用了function_forward代替模型本身的forward进行损失计算和摘出来的参数更新,用function_forward代替模型本身forward,从而实现的最后的对模型初始参数的更新。 对于复杂模型,有无简单方法?