mims-harvard / TDC

Therapeutics Commons: Artificial Intelligence Foundation for Therapeutic Science
https://tdcommons.ai
MIT License

Permutation bug in DrugComb group evaluation #169

Closed cnedwards closed 1 year ago

cnedwards commented 1 year ago

Describe the bug: The ordering that group evaluation expects for its input on drug synergy tasks is unclear.

This affects the MLP example (https://github.com/mims-harvard/TDC/blob/main/examples/multi_pred/drugcombo/train_MLP.py), since it uses a dataloader with shuffling turned on.

To Reproduce: Running the following code shows the behavior.

import numpy as np

from tdc.benchmark_group import drugcombo_group
group = drugcombo_group(path = 'data/')

benchmark = group.get('Drugcomb_ZIP')

predictions = {}
name = benchmark['name']
train_val, test = benchmark['train_val'], benchmark['test']
train, valid = group.get_train_valid_split(benchmark = name, split_type = 'default', seed = 1)

print('Different results from permutations:')
tmp = {name:np.arange(59708)}
print('Perm 1:', group.evaluate(tmp))
tmp = {name:np.random.permutation(np.arange(59708))}
print('Perm 2:', group.evaluate(tmp))

print('Better results than MLP on leaderboard from a constant:')
tmp = {name:np.zeros(59708)-1.095}
print('Constant:', group.evaluate(tmp))

from sklearn.metrics import mean_absolute_error
print('An example of why certain permutations are much better:')
gt = np.arange(20)*10
pred = gt - 0.5
print('Correct MAE:', mean_absolute_error(gt, pred))
print('Permutation MAE 1:', mean_absolute_error(gt, np.random.permutation(pred)))
print('Permutation MAE 2:', mean_absolute_error(gt, np.random.permutation(pred)))
print('Permutation MAE 3:', mean_absolute_error(gt, np.random.permutation(pred)))
print('Permutation MAE 4:', mean_absolute_error(gt, np.random.permutation(pred)))
print('Permutation MAE 5:', mean_absolute_error(gt, np.random.permutation(pred)))

This outputs the following on my machine:

generating training, validation splits...
Different results from permutations:
Perm 1: {'drugcomb_zip': {'mae': 29854.6109218136}}
Perm 2: {'drugcomb_zip': {'mae': 29854.610926591966}}
Better results than MLP on leaderboard from a constant:
Constant: {'drugcomb_zip': {'mae': 4.01731854567091}}
An example of why certain permutations are much better:
Correct MAE: 0.5
Permutation MAE 1: 54.05
Permutation MAE 2: 81.9
Permutation MAE 3: 69.9
Permutation MAE 4: 67.0
Permutation MAE 5: 69.95

Note that I modified the code in the library to round to more decimal places. By default it is 3, which makes the function look permutation invariant (e.g. both 29854.6109218136 and 29854.610926591966 round to 29854.611).
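A minimal sketch of why the two permutations above land so close together, using synthetic labels rather than the DrugComb data (the uniform range below is an assumption chosen only for illustration). When nearly every prediction is larger than its label, |pred - y| equals pred - y for almost all rows, so the sum barely depends on the ordering:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 59708

# Synthetic stand-in for the test labels (assumption: small values in [-10, 10]).
y_true = rng.uniform(-10.0, 10.0, size=n)


def mae(y, p):
    """Mean absolute error between two equally sized arrays."""
    return float(np.mean(np.abs(y - p)))


# Predictions 0..n-1, as in the reproduction above.
pred = np.arange(n, dtype=float)

mae_identity = mae(y_true, pred)
mae_permuted = mae(y_true, rng.permutation(pred))

# Only the handful of predictions smaller than a label can change the sum,
# so the two values differ by far less than the labels' own scale.
print('Identity order:', mae_identity)
print('Permuted order:', mae_permuted)
print('Absolute difference:', abs(mae_identity - mae_permuted))
```

With predictions on a scale comparable to the labels, as in the 20-element example above, the same shuffle changes the MAE substantially.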

Expected behavior: The evaluate method should require the ordering of its input to be specified, so that predictions are compared against the correct ground-truth entries.

If I just missed the correct usage of the group.evaluate() method, please let me know. Thanks!

kexinhuang12345 commented 1 year ago

Hi Carl, the y_pred passed to group.evaluate(y_pred) is assumed to be in the same order as the fixed test split, i.e. the test variable from the benchmark group. TDC fixes the test set for a fair comparison.
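One way to satisfy this ordering requirement when predictions were produced in a shuffled order is to record the row order and scatter the predictions back before evaluating. The sketch below uses only NumPy with toy data; the names `order`, `pred_shuffled`, and `pred_aligned` are hypothetical and not part of TDC. In a PyTorch setup, turning shuffling off for the test loader avoids the problem altogether.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy ground truth in the fixed test-split order.
y_test = np.arange(20) * 10.0

# Hypothetical: the order in which a shuffling loader visited the test rows.
order = rng.permutation(len(y_test))

# Predictions as the model emitted them, i.e. in visited order.
pred_shuffled = y_test[order] - 0.5

# Scatter each prediction back to the position of the row it belongs to.
pred_aligned = np.empty_like(pred_shuffled)
pred_aligned[order] = pred_shuffled

mae_aligned = float(np.mean(np.abs(y_test - pred_aligned)))
mae_misaligned = float(np.mean(np.abs(y_test - pred_shuffled)))

print('Aligned MAE:', mae_aligned)
print('Misaligned MAE:', mae_misaligned)
```

Once realigned, an array like `pred_aligned` is in the order the reply above describes and could be passed as the value in the dictionary given to group.evaluate.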

I suspect the MLP scores worse than the constant simply because the MLP performs poorly on this task.

kexinhuang12345 commented 1 year ago

Closing for now! Feel free to reopen if you have any further questions!