stefanonardo / pytorch-esn

An Echo State Network module for PyTorch.
MIT License
205 stars 43 forks source link

The gradients are None, please help me. #14

Open yiyi123123132 opened 1 year ago

yiyi123123132 commented 1 year ago

I don't know why the gradients are None. This is my network code (CNN + ESN):

class CNN_ESN(nn.Module):
    """CNN feature extractor -> Echo State Network -> linear + sigmoid head.

    ``conv`` maps each input to a 512-value feature vector (its last layer is
    ``nn.Linear(1024, 512)`` after ``nn.Flatten``). ``forward`` splits each
    sample's vector into ``_SEQ_LEN`` (128) consecutive steps of 4 values. It
    runs that time-major sequence through ``esn``, then feeds the hidden state
    returned by ``esn`` through ``fc`` and a sigmoid.

    NOTE(review): ``nn.Linear(1024, 512)`` needs exactly 1024 values per sample
    after ``Flatten``. That is 64 channels x (H/2 * W/2) positions after the
    2x2 max-pool, so H*W == 64 (e.g. 8x8 inputs). Confirm against the data.

    NOTE(review): observed behaviour: after ``loss.backward()`` only ``fc`` has
    gradients, and every ``conv.*`` and ``esn.*`` parameter has ``grad is None``.
    So no gradient flows back through ``esn`` into ``conv``, and backprop can
    only train ``fc``. Presumably ``ESN.forward`` runs its reservoir without
    autograd (e.g. under ``torch.no_grad()``). The ``readout_training='svd'``
    readout is presumably meant to be fitted offline rather than by an
    optimizer. Confirm both against the torchesn source before relying on
    end-to-end training.
    """

    # Number of time steps each 512-value CNN feature vector is split into.
    _SEQ_LEN = 128

    def __init__(self, output_size, channel_num, device, drop_prob=0.5):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(channel_num, 64, (5, 5), padding='same'),
            nn.ReLU(),
            nn.Conv2d(64, 128, (4, 4), padding='same'),
            nn.ReLU(),
            nn.Conv2d(128, 256, (4, 4), padding='same'),
            nn.ReLU(),
            nn.Conv2d(256, 64, (1, 1), padding='same'),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            nn.Flatten(),
            nn.Linear(1024, 512),
        ).to(device)
        # 4 == 512 features / 128 steps: the per-step feature count fed to the
        # ESN (presumably its input size -- confirm against the ESN signature).
        self.esn = ESN(4, 128, 128, output_steps='mean', readout_training='svd').to(device)
        self.fc = nn.Linear(128, output_size).to(device)
        self.dropout = nn.Dropout(drop_prob)
        self.sig = nn.Sigmoid()
        # Fraction of the sequence length passed to the ESN as the washout.
        self.washout_rate = 0.2

    @staticmethod
    def _to_time_major(features, seq_len):
        """Split each row of a 2-D ``features`` tensor into ``seq_len`` chunks.

        ``features`` is (batch, F). The result is (seq_len, batch, F // seq_len),
        where ``result[t, b]`` holds the t-th consecutive chunk of row ``b``.
        Reshaping to (batch, seq_len, -1) *before* moving the time axis first
        keeps every sequence made of a single sample's values. A direct
        ``reshape(seq_len, batch, -1)`` would mix values from different
        samples into one sequence.
        """
        batch = features.size(0)
        return features.reshape(batch, seq_len, -1).transpose(0, 1).contiguous()

    def forward(self, x):
        """Return sigmoid outputs for the batch ``x``.

        With ``output_size == 1`` the two ``squeeze(1)`` calls leave one value
        per sample.
        """
        features = self.conv(x)  # (batch, 512): Flatten + Linear(1024, 512)
        seq = self.dropout(self._to_time_major(features, self._SEQ_LEN))
        seq_len, batch = seq.size(0), seq.size(1)
        # One washout length per sequence in the batch.
        washout = [int(self.washout_rate * seq_len)] * batch
        _, hn = self.esn(seq, washout)
        # Move the batch axis first so fc / squeeze act per sample.
        hn = hn.transpose(1, 0)
        logit = self.fc(hn).squeeze(1)
        return self.sig(logit).squeeze(1)

    # One training pass over train_loader. Put the model in training mode once:
    # nothing inside the loop changes the mode, so per-batch net.train() calls
    # were redundant.
    net.train()
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        output = net(inputs)
        loss = loss_func(output, labels)

        # Predictions are bookkeeping only. Detach so rounding does not extend
        # the autograd graph. If net is CNN_ESN, whose forward ends in a
        # sigmoid, rounding thresholds at 0.5.
        pred = output.detach().round()
        # zip keeps the original truncate-to-shorter pairing semantics.
        for p, l in zip(pred.tolist(), labels.tolist()):
            tr_pre.append(int(p))
            tr_lab.append(l)
        tr_loss.append(loss.item())

        loss.backward()
        optimizer.step()

And here are the gradients I printed (results from the 2nd epoch):

conv.0.weight None conv.0.bias None conv.2.weight None conv.2.bias None conv.4.weight None conv.4.bias None conv.6.weight None conv.6.bias None conv.10.weight None conv.10.bias None esn.reservoir.weight_ih_l0 None esn.reservoir.weight_hh_l0 None esn.reservoir.bias_ih_l0 None esn.readout.weight None esn.readout.bias None fc.weight tensor([[ 0.0480, -0.0079, 0.0929, 0.1358, 0.1022, -0.1127, 0.0495, 0.1056, -0.0923, 0.0720, 0.1122, 0.0139, -0.0619, 0.0796, 0.1433, 0.0295, 0.0884, -0.0504, 0.0305, 0.0264, 0.1352, 0.0467, -0.0607, -0.0363, -0.0114, -0.1393, -0.0917, 0.0194, 0.1076, -0.0713, -0.0487, 0.0433, -0.0875, 0.0212, 0.1007, -0.0711, 0.1098, 0.0577, 0.0607, 0.0299, 0.0380, 0.0955, 0.0062, -0.0620, -0.0463, -0.0354, 0.1050, -0.0920, 0.0742, -0.0550, -0.1270, -0.0597, 0.0736, 0.0246, 0.0521, -0.0866, -0.0065, -0.0764, 0.0087, -0.0810, 0.0551, 0.0999, 0.1078, -0.0082, -0.0940, -0.0628, -0.0624, 0.0779, -0.0107, -0.0069, 0.0793, -0.0318, 0.0086, -0.1427, 0.0617, 0.0839, -0.0904, 0.0535, 0.0678, -0.0232, 0.0361, -0.0065, -0.0806, 0.0824, -0.0336, 0.0707, -0.1042, -0.0336, -0.1039, 0.0715, -0.1403, -0.0068, -0.0912, -0.1043, 0.0455, -0.0223, 0.0448, 0.0321, 0.0844, -0.0734, -0.0050, 0.1083, 0.0027, -0.0897, -0.1176, -0.1059, 0.1006, 0.0873, -0.0715, -0.1130, 0.0707, -0.0015, 0.0257, -0.0536, -0.0769, -0.0671, -0.0329, -0.0666, -0.0631, 0.1100, -0.1013, -0.0392, -0.1062, 0.1276, -0.0686, -0.1200, -0.0669, 0.0424]]) fc.bias tensor([0.1520])

Could you tell me the reason?

yiyi123123132 commented 1 year ago

net = md.CNN_ESN(
    output_size=output_size,
    channel_num=channel_num,
    device=device,
    drop_prob=0.5,
)