onlyzdd / ecg-diagnosis

Deep learning for 12-lead ECG interpretation
125 stars · 35 forks

question about `output_list`, `labels_list` and `scheduler.step()` in `main.py` #10

Closed · rdyan0053 closed this issue 1 year ago

rdyan0053 commented 1 year ago

https://github.com/onlyzdd/ecg-diagnosis/blob/dfa9033d5ae7be135db63ff567e66fdb2b86d76d/main.py#L45C1-L47C23

Hi, I have a question about the code linked above in main.py. What is the role of labels_list and output_list here? I understand that in the evaluate function these two variables are used to compute the F1 score, but in the train function they do not seem to be needed.

Another question: why is scheduler.step() commented out here? Is this a small mistake?

Looking forward to your answer, thank you!

rdyan0053 commented 1 year ago

In addition, do we need to wrap the code below in torch.no_grad()?

def evaluate(dataloader, net, args, criterion, device):
    print('Validating...')
    net.eval()
    running_loss = 0
    output_list, labels_list = [], []
    # add code torch.no_grad()
    for _, (data, labels) in enumerate(tqdm(dataloader)):
        data, labels = data.to(device), labels.to(device)
        output = net(data)
        loss = criterion(output, labels)
        running_loss += loss.item()
        output = torch.sigmoid(output)
        output_list.append(output.data.cpu().numpy())
        labels_list.append(labels.data.cpu().numpy())
    print('Loss: %.4f' % running_loss)
onlyzdd commented 1 year ago

@rdyan0053 To answer your questions briefly: you can remove those lines or comment them out, as they are not used in the train function.
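
For illustration, here is a hypothetical minimal sketch of what a training loop might look like with the unused output_list/labels_list removed and scheduler.step() re-enabled. The model, data, and hyperparameters below are stand-ins, not the repository's actual train() code.

```python
# Hypothetical sketch: minimal training loop without the unused lists,
# calling scheduler.step() once per epoch (stand-in model and data).
import torch
import torch.nn as nn

net = nn.Linear(4, 2)                      # stand-in for the ECG model
criterion = nn.BCEWithLogitsLoss()         # multi-label loss, as in the repo
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

data = torch.randn(8, 4)                   # dummy batch
labels = torch.randint(0, 2, (8, 2)).float()

for epoch in range(2):
    net.train()
    optimizer.zero_grad()                  # clear gradients from the last step
    loss = criterion(net(data), labels)
    loss.backward()
    optimizer.step()
    scheduler.step()                       # decay the learning rate each epoch

print(optimizer.param_groups[0]['lr'])
```

With step_size=1 and gamma=0.5, the learning rate halves every epoch, so after two epochs it has dropped from 0.1 to 0.025.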

onlyzdd commented 1 year ago

@rdyan0053 You are right. We can wrap the code in torch.no_grad(), which gives a slight speed-up (and saves memory) during evaluation. But as for the results, it does not matter, since we zero the gradients before backpropagation in each training step anyway.
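
Concretely, the evaluation loop could look like the sketch below. The model, loss, and dataloader are stand-ins (the real evaluate() takes dataloader, net, args, criterion, device), but the no_grad wrapping is the point being illustrated.

```python
# Sketch of the evaluate() loop wrapped in torch.no_grad()
# (stand-in model and data, not the repository's actual code).
import torch
import torch.nn as nn

net = nn.Linear(4, 2)                      # stand-in for the ECG model
criterion = nn.BCEWithLogitsLoss()
# One dummy batch in place of a real DataLoader
dataloader = [(torch.randn(8, 4), torch.randint(0, 2, (8, 2)).float())]

net.eval()
running_loss = 0.0
output_list, labels_list = [], []
with torch.no_grad():                      # no autograd graph is built here
    for data, labels in dataloader:
        output = net(data)
        loss = criterion(output, labels)
        running_loss += loss.item()
        output = torch.sigmoid(output)
        output_list.append(output.cpu().numpy())
        labels_list.append(labels.cpu().numpy())
        assert not output.requires_grad    # outputs carry no gradient history
print('Loss: %.4f' % running_loss)
```

Because no graph is stored, activations are freed immediately, which is where the memory saving comes from; the computed loss and sigmoid outputs are identical either way.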

rdyan0053 commented 1 year ago

Thanks!