Thanks for your detailed analysis of the loss calculation in our training script.
I appreciate your suggestions, and I also want to respectfully clarify that our implementation strictly follows the original methodology of the Chemprop code: the loss calculation is performed here exactly as it is in upstream Chemprop.
However, if you believe a different approach to the loss calculation would better serve your specific use case, you're absolutely welcome to modify the script and experiment. Keep in mind, though, that changing how the loss is computed can alter the training dynamics, which may require retuning other hyperparameters, such as the learning rate, to maintain performance.
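One concrete reason for the learning-rate caveat: the scale of the loss directly scales the gradients. Here is a toy PyTorch sketch (not Chemprop code; `mse_loss` is used purely for illustration) showing that switching a loss from a per-element mean to a sum multiplies every gradient by the element count, which is why the learning rate typically needs a matching adjustment:

```python
import torch
import torch.nn.functional as F

# Toy illustration (not Chemprop code): switching the loss reduction from
# 'mean' to 'sum' scales every gradient by the number of elements, so the
# learning rate usually needs to shrink by the same factor.
x = torch.randn(8, 3, requires_grad=True)
target = torch.randn(8, 3)

F.mse_loss(x, target, reduction='mean').backward()
grad_mean = x.grad.clone()

x.grad = None
F.mse_loss(x, target, reduction='sum').backward()
grad_sum = x.grad.clone()

# grad_sum equals grad_mean * x.numel() (8 * 3 = 24) elementwise.
print(torch.allclose(grad_sum, grad_mean * x.numel()))  # True
```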
In `chemprop/train/run_training.py`:

```python
...
for epoch in range(args.epochs):
    info(f'train_data_size: {len(train_data)}; val_data_size: {len(val_data)}; test_data_size: {len(test_data)}')
    info(f'Epoch {epoch}')
```
If I add `info(f'Train loss = {loss:.6f}')` there, the loss it prints is much smaller than I expected.
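For context, the change I made is essentially this (a self-contained sketch with a stubbed train step, not the real Chemprop loop; `run_epoch` and its placeholder return value are hypothetical, and `info` is just a logger method as in the snippet above):

```python
import logging

logging.basicConfig(level=logging.INFO)
info = logging.getLogger(__name__).info

def run_epoch() -> float:
    # Stand-in for the real train() call; returns the batch-average loss,
    # i.e. the 'loss = loss.sum() / mask.sum()' value from train.py below.
    return 0.0123  # placeholder

for epoch in range(3):  # stand-in for range(args.epochs)
    info(f'Epoch {epoch}')
    loss = run_epoch()
    info(f'Train loss = {loss:.6f}')
```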
So, in `chemprop/train/train.py`, the relevant code reads:

```python
...
mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
targets = torch.Tensor([[0 if x is None else x for x in tb] for tb in target_batch])
...
preds = model(step, prompt, batch, features_batch)

if args.dataset_type == 'multiclass':
    targets = targets.long()
    loss = torch.cat([loss_func(preds[:, target_index, :], targets[:, target_index]).unsqueeze(1)
                      for target_index in range(preds.size(1))], dim=1) * class_weights * mask
else:
    loss = loss_func(preds, targets) * class_weights * mask

loss = loss.sum() / mask.sum()
...
# Log and/or add to tensorboard
```
Here I think you are summing batch-average losses and then dividing by the number of samples. If you print the result of `loss = loss.sum() / mask.sum()`, you will see that the value computed here is the average loss over a batch, not the sum of losses in the batch.
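A quick numerical check (a toy sketch, not the training loop; the shapes and values are made up, `BCELoss` is a stand-in for `loss_func`, and `class_weights` is omitted as if it were all ones):

```python
import torch

# Toy batch: 4 molecules, 2 tasks; the second target of molecule 3 is missing.
preds = torch.tensor([[0.2, 0.8], [0.6, 0.4], [0.9, 0.1], [0.3, 0.7]])
targets = torch.tensor([[0., 1.], [1., 0.], [1., 0.], [0., 1.]])
mask = torch.tensor([[1., 1.], [1., 1.], [1., 0.], [1., 1.]])

loss_func = torch.nn.BCELoss(reduction='none')
loss = loss_func(preds, targets) * mask  # zero out the missing entry

# Dividing the summed per-entry losses by the number of valid entries
# (mask.sum() == 7) gives the *average* loss per entry in this batch,
# not the batch's total loss.
print(loss.sum() / mask.sum())
```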
Shouldn't it be `loss_sum += batch_average_loss * batch_size`, followed by `loss_sum / iter_count`, if you want to calculate the training loss? Alternatively, it could divide the sum of the batch-average losses by the number of batches.
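To make the suggestion concrete, a minimal sketch with stubbed numbers (each pair stands for the `loss.sum() / mask.sum()` value and the `mask.sum()` count of one batch; none of this is Chemprop code):

```python
# Stub batches: (batch_average_loss, number_of_valid_targets) pairs.
batches = [(0.9, 10), (0.5, 100)]

loss_sum, sample_count = 0.0, 0
for batch_average_loss, batch_size in batches:
    # Undo the per-batch averaging before accumulating ...
    loss_sum += batch_average_loss * batch_size
    sample_count += batch_size

# ... then divide once by the total number of samples in the epoch.
print(loss_sum / sample_count)                    # ~0.536, weighted by batch size

# Averaging the batch averages instead (dividing by the number of batches)
# gives a different value when batch sizes differ:
print(sum(b[0] for b in batches) / len(batches))  # 0.7
```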