RUCAIBox / RecBole

A unified, comprehensive and efficient recommendation library
https://recbole.io/
MIT License

Question about modified model parameters being overwritten by later updates #1973

Closed ithok closed 5 months ago

ithok commented 5 months ago

Hello! I'd like to ask about modifying model parameters. As the title says, I built a GCN network on top of RecBole and added a custom mask, which is pruned at a specified frequency (e.g. once every 20 epochs). However, I found that after I modify the model parameters directly, they are restored in the next epoch, as if the pruning had never happened; within a single epoch the pruning does take effect. Part of the problematic code is shown below:

This code is in the trainer:

    if self.model.lotteryflag and epoch_idx % 10 == 0 and epoch_idx > 10:
        print("prune!")
        self.model.update_norm_adj()

Part of update_norm_adj() is shown below.

Reset adj_mask to obtain a new adj_mask:

    ones = torch.ones_like(self.adj_mask_train)
    zeros = torch.zeros_like(self.adj_mask_train)
    adj_mask = torch.where(self.adj_mask_train.abs() >= adj_thre, ones, zeros)
    # update adj_mask by rebinding it to a new Parameter
    self.adj_mask_train = torch.nn.Parameter(adj_mask, requires_grad=True)

I'm not sure what the problem actually is; I would appreciate your explanation and a solution. Many thanks!
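For completeness, here is a minimal standalone sketch of one thing the rebinding above implies (a toy module, not the actual RecBole trainer): an optimizer that was built earlier from model.parameters() keeps a reference to the old Parameter object and never sees the replacement.

    import torch
    import torch.nn as nn
    import torch.optim as optim

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.adj_mask_train = nn.Parameter(torch.ones(3))

    model = Toy()
    optimizer = optim.Adam(model.parameters(), lr=0.1)

    old_param = model.adj_mask_train
    # rebind the attribute to a brand-new Parameter, as update_norm_adj() does
    model.adj_mask_train = nn.Parameter(torch.zeros(3), requires_grad=True)

    tracked = [p for g in optimizer.param_groups for p in g["params"]]
    print(any(p is model.adj_mask_train for p in tracked))  # False: the new Parameter is not tracked
    print(any(p is old_param for p in tracked))             # True: the optimizer still updates the old one

Whether this matters here depends on whether the trainer rebuilds its optimizer after update_norm_adj() is called.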

ithok commented 5 months ago

I posted this question on the PyTorch forum: https://discuss.pytorch.org/t/manually-modified-parameter-during-trainning-gradients-are-not-being-correctly-updated-help/195782/2 The problem located so far seems possibly related to the optimizer, because even when I explicitly zero the grad after backward, the parameter is still updated.
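A tiny standalone check of that hypothesis (not taken from the model code): once Adam has accumulated moment estimates, a step with an all-zero gradient still moves the parameter.

    import torch
    import torch.optim as optim

    p = torch.nn.Parameter(torch.tensor([1.0]))
    optimizer = optim.Adam([p], lr=0.1)

    # one normal step builds up Adam's exp_avg / exp_avg_sq state
    p.grad = torch.tensor([1.0])
    optimizer.step()

    # explicitly zero the gradient, then step again
    p.grad = torch.zeros_like(p)
    before = p.detach().clone()
    optimizer.step()
    print((p.detach() - before).abs().item())  # > 0: Adam still updates p from its momentum state
    # with optim.SGD([p], lr=0.1) instead, the same zero-gradient step leaves p unchanged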

Fotiligner commented 5 months ago

@ithok Hello, we are currently unable to reproduce your issue. Could you provide more detailed code or a more detailed description of the problem?

ithok commented 5 months ago

> @ithok Hello, we are currently unable to reproduce your issue. Could you provide more detailed code or a more detailed description of the problem?

Hello, you can refer to this thread: https://discuss.pytorch.org/t/unexpected-gradient-presence-in-zero-valued-positions-of-trainable-parameters-after-value-replacement-during-training/196005 I posted the detailed code in the earlier link, but pasting all of it is difficult, so I implemented a minimal runnable example. It exhibits the same problem and can be run directly:

import torch
import torch.nn as nn
import torch.optim as optim

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.adj = nn.Parameter(torch.Tensor(5, 5).fill_(0.5), requires_grad=False)  # fixed adj
        self.adj_mask = nn.Parameter(torch.ones(5, 5), requires_grad=True)  # trainable adj_mask

    def forward(self):
        masked_adj = torch.mul(self.adj, self.adj_mask)  # mask adj with adj_mask
        return masked_adj

def print_tensor_info(tensor):
    num_zeros = (tensor == 0).sum().item()
    num_non_zeros = tensor.numel() - num_zeros
    max_val, min_val = tensor.max(), tensor.min()
    has_nan = torch.isnan(tensor).any().item()

    print(f"Zero values: {num_zeros}")
    print(f"Non-zero values: {num_non_zeros}")
    print(f"Max value: {max_val.item()}")
    print(f"Min value: {min_val.item()}")
    print(f"Has NaN: {has_nan}")

# prune the mask
def prune_mask(mask, percentage):
    with torch.no_grad():
        # print("before prune")
        # print_tensor_info(mask)
        adj_mask_tensor = mask.flatten()
        # print("before_adj_mask_tensor.shape: ", adj_mask_tensor.shape)
        nonzero = torch.abs(adj_mask_tensor) > 0
        adj_mask = adj_mask_tensor[nonzero]  # 13264 - 2708
        # print("before_adj_mask.shape: ", adj_mask.shape)
        # print(adj_mask)
        adj_total = adj_mask.shape[0]
        adj_y, adj_i = torch.sort(adj_mask.abs())
        adj_thre_index = int(adj_total * percentage)
        adj_thre = adj_y[adj_thre_index]
        # print("adj_thre", adj_thre)
        abs_values = torch.abs(mask)
        index = abs_values >= adj_thre
        mask.data[index] = 1
        mask.data[~index] = 0
        # print("After prune")
        # print_tensor_info(mask)
    # print("-------")

model = SimpleModel()

optimizer = optim.Adam([model.adj_mask], lr=10)

# input_data = torch.randn(5, 5)
target = torch.randn(5, 5)

pruneflag = False
for epoch in range(41):

    # if pruneflag:
    #     print("Check whether adj_mask has been pruned before forward propagation")
    #     print_tensor_info(model.adj_mask)
    masked_adj = model()
    loss = nn.MSELoss()(masked_adj, target)

    # if pruneflag:
    #     print("Check whether adj_mask has been pruned before backward propagation")
    #     print_tensor_info(model.adj_mask)
    optimizer.zero_grad()
    # if pruneflag:
    #     print("Check whether the adj_mask gradient has been cleared before backpropagation")
    #     print_tensor_info(model.adj_mask.grad)
    loss.backward()
    if pruneflag:
        print("Check adj_mask gradient after backpropagation")
        print_tensor_info(model.adj_mask.grad)
        # optimizer.zero_grad()
        # print_tensor_info(model.adj_mask.grad)
    optimizer.step()
    if pruneflag:
        print("Check adj_mask after backpropagation")
        print_tensor_info(model.adj_mask)
        pruneflag = False

    # prune every 20 epochs
    if epoch % 20 == 0:
        print("----epoch" + str(epoch))
        prune_mask(model.adj_mask, 0.2)
        pruneflag = True

My questions are:

(1) After modifying the adj_mask tensor, during backpropagation I observe that the gradient on adj_mask is not updated as expected: elements that are multiplied by zero in the Hadamard product unexpectedly receive gradients, which is counterintuitive.

(2) In addition, even though I manually set the gradients to zero, adj_mask keeps being updated by optimizer.step() in my code (the commented-out part). I assume this may be related to the momentum inherent in the Adam optimizer, because the problem does not occur when I switch to SGD. Still, this assumption does not explain the first question.

So strictly speaking, the parameters are not reverting to their previous state; rather, they are receiving gradients they should not have.
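For reference, question (1) can be checked in isolation (a standalone snippet, not the model above): for masked_adj = adj * adj_mask, the gradient on adj_mask is the upstream gradient times adj, which does not depend on the current value of adj_mask, so entries that were set to zero still receive nonzero gradients.

    import torch

    adj = torch.full((2, 2), 0.5)
    # two entries of the mask are "pruned" to zero
    mask = torch.nn.Parameter(torch.tensor([[1.0, 0.0],
                                            [0.0, 1.0]]))

    loss = (adj * mask).sum()
    loss.backward()
    print(mask.grad)  # every entry is 0.5, including the pruned (zero) positions

This suggests that nonzero gradients at the pruned positions are the expected autograd behaviour for an elementwise product; keeping those entries at zero would require re-applying the mask after each optimizer step.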