awslabs / dgl-lifesci

Python package for graph neural networks in chemistry and biology
Apache License 2.0
714 stars, 147 forks

Small bug in classification_inference for csv_data_configuration #224

Open HFooladi opened 8 months ago

HFooladi commented 8 months ago

There is a small bug in `examples/property_prediction/csv_data_configuration/classification_inference.py`.

On line 37, the `predict` function returns logits rather than probabilities, so in principle its output can take any value in (-inf, inf):

```python
batch_pred = predict(args, model, bg)  # raw logits, not probabilities
if not args['soft_classification']:
    batch_pred = (batch_pred >= 0.5).float()  # bug: thresholds logits, not probabilities
predictions.append(batch_pred.detach().cpu())
```
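
To see why this matters, here is a small pure-Python sketch (no torch needed) using a few hypothetical logit values; `sigmoid` is a local helper written only for illustration, not a dgl-lifesci function. Because sigmoid(0) = 0.5, any logit in [0, 0.5) corresponds to a probability of at least 0.5, yet the current code labels it 0. In soft-classification mode, the script would also save raw logits instead of probabilities.

```python
import math

def sigmoid(x: float) -> float:
    """Logistic function: maps any real-valued logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

# Hypothetical logits, standing in for what predict() might return for one batch.
logits = [-2.0, 0.0, 0.3, 1.5]

# Current behavior: threshold the raw logits at 0.5.
buggy = [1.0 if z >= 0.5 else 0.0 for z in logits]

# Proposed behavior: convert to probabilities first, then threshold at 0.5.
fixed = [1.0 if sigmoid(z) >= 0.5 else 0.0 for z in logits]

print(buggy)  # [0.0, 0.0, 0.0, 1.0]
print(fixed)  # [0.0, 1.0, 1.0, 1.0]
```

The logits 0.0 and 0.3 (probabilities 0.5 and about 0.57) are the ones the current code misclassifies.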

So the logits should first be mapped to probabilities in [0, 1] with a sigmoid, and only then used as soft labels or thresholded into hard labels.

```python
batch_logit = predict(args, model, bg)
batch_pred = torch.sigmoid(batch_logit)  # map logits to probabilities in (0, 1)
if not args['soft_classification']:
    batch_pred = (batch_pred >= 0.5).float()  # hard 0/1 labels
predictions.append(batch_pred.detach().cpu())
```

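
For anyone patching their own fork, the same logic can be pulled into a small function. This is a sketch under assumptions: `logits_to_predictions` is a hypothetical helper (not part of dgl-lifesci), and it assumes one logit per binary task, as the sigmoid-based fix above implies.

```python
import torch

def logits_to_predictions(logits: torch.Tensor, soft_classification: bool) -> torch.Tensor:
    """Hypothetical helper mirroring the proposed fix.

    Soft mode returns per-task probabilities; hard mode returns 0/1 labels.
    """
    probs = torch.sigmoid(logits)  # (-inf, inf) -> (0, 1)
    if soft_classification:
        return probs
    return (probs >= 0.5).float()

# Hypothetical batch: 2 molecules x 2 tasks.
logits = torch.tensor([[-2.0, 0.0], [0.3, 1.5]])
hard = logits_to_predictions(logits, soft_classification=False)
soft = logits_to_predictions(logits, soft_classification=True)

# sigmoid is monotonic with sigmoid(0) = 0.5, so for these values
# thresholding probabilities at 0.5 matches thresholding logits at 0.
assert torch.equal(hard, (logits >= 0).float())
print(hard)
print(soft)
```

Since sigmoid is monotonic, hard labels could equivalently be computed as `(batch_logit >= 0).float()` without the sigmoid, up to floating-point rounding for logits extremely close to 0; the sigmoid is still needed for soft classification.
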
mufeili commented 8 months ago

Nice catch! Thank you for the report. Unfortunately, I've left AWS and cannot update the codebase or approve PRs from others. You may modify your own fork if you need to use this functionality.