mims-harvard / TDC

Therapeutics Commons (TDC-2): Multimodal Foundation for Therapeutic Science
https://tdcommons.ai
MIT License

Support multiple instances in cold split #126

Closed jannisborn closed 2 years ago

jannisborn commented 3 years ago

Describe the problem: The current cold split can only split on the instances of a single modality, not on multiple modalities at once. For example, the DrugRes dataset cannot be split such that the test samples contain only drugs and cell lines that the model has not seen during training.

Describe the solution you'd like: Functionality like this:

split = data.get_split(method = 'cold_split', column_names = ['Drug', 'Cell Line'])

that works for any multi-instance dataset and can split on multiple columns.

Is this a feature you have conceived of, but intentionally refrained from implementing? Or would this be a valuable contribution to the package?

kexinhuang12345 commented 3 years ago

Hi Jannis, thanks for the issue. I think it would be a great addition. Do you have the bandwidth to contribute? If not, we would also be glad to implement it. Let us know, thanks!

jannisborn commented 3 years ago

Sure, let's see where we can go together. I need a bit of guidance though.
Is this the bit of code that needs updating: https://github.com/mims-harvard/TDC/blob/main/tdc/utils/split.py#L29 ?

Also, keep in mind that this split implies that some samples will remain unused (i.e., neither assigned to train nor to test data).
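To make the point about unused samples concrete, here is a minimal sketch on a toy table. The drug and cell line names and the hand-picked held-out sets are assumptions for illustration only; nothing here is TDC code.

```python
import pandas as pd

# Toy drug-response table: every drug is paired with every cell line.
df = pd.DataFrame(
    [(d, c) for d in ["d1", "d2", "d3", "d4"] for c in ["c1", "c2", "c3", "c4"]],
    columns=["Drug", "Cell Line"],
)

# Hold out one drug and one cell line for testing (chosen by hand here).
test_drugs, test_cells = {"d4"}, {"c4"}

in_test_drug = df["Drug"].isin(test_drugs)
in_test_cell = df["Cell Line"].isin(test_cells)

test = df[in_test_drug & in_test_cell]     # both entities are unseen
train = df[~in_test_drug & ~in_test_cell]  # both entities are seen
unused = df[in_test_drug ^ in_test_cell]   # exactly one entity is held out

print(len(train), len(test), len(unused))  # 9 1 6
```

Of the 16 pairs, 6 mix a seen entity with an unseen one, so they can go neither into train nor into test without leaking information.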

kexinhuang12345 commented 3 years ago

Glad to hear. Yes, I think the simplest solution would be to create a separate function called create_fold_setting_cold_multi. The reason is that multiple classes already use the existing function (create_fold_setting_cold), so changing it in place would make backward compatibility tricky. After you have implemented the utility, you can add a call to it in https://github.com/mims-harvard/TDC/blob/0e2611e352684cefd35454213f70afd0e9e6db77/tdc/multi_pred/bi_pred_dataset.py#L140. Maybe by setting method='cold_split_multi'?

This is the class that the DrugRes dataloader inherits from, i.e. the base class for all multi-instance datasets built on pairs of two entities. Let me know if you have any questions!
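One possible shape for such a utility is sketched below. This is only an illustration of the idea discussed in this thread, not TDC's implementation: the function name cold_split_multi_sketch, its signature, and the choice to apply the fractions per entity rather than per row are all assumptions.

```python
import numpy as np
import pandas as pd


def cold_split_multi_sketch(df, columns, frac=(0.7, 0.1, 0.2), seed=42):
    """Hypothetical multi-column cold split (illustrative, not TDC code).

    Each unique entity in each column is assigned to one fold. A row is kept
    only if all of its entities fall into the same fold; other rows are dropped.
    """
    _, valid_frac, test_frac = frac
    rng = np.random.default_rng(seed)

    fold_of = {}
    for col in columns:
        # Shuffle the unique entities of this column.
        entities = rng.permutation(np.asarray(df[col].unique(), dtype=object))
        n = len(entities)
        n_test = int(round(n * test_frac))
        n_valid = int(round(n * valid_frac))
        # The first n_test entities go to test, the next n_valid to valid,
        # and the remainder to train.
        labels = np.array(["train"] * n, dtype=object)
        labels[:n_test] = "test"
        labels[n_test:n_test + n_valid] = "valid"
        fold_of[col] = dict(zip(entities, labels))

    # One fold label per row and per split column.
    row_folds = np.column_stack(
        [df[col].map(fold_of[col]).to_numpy() for col in columns]
    )
    first = row_folds[:, 0]
    agree = (row_folds == row_folds[:, [0]]).all(axis=1)

    return {
        name: df[agree & (first == name)].reset_index(drop=True)
        for name in ("train", "valid", "test")
    }


# Usage on a toy 10 x 10 grid of drug / cell line pairs.
df_demo = pd.DataFrame(
    [(f"d{i}", f"c{j}") for i in range(10) for j in range(10)],
    columns=["Drug", "Cell Line"],
)
split = cold_split_multi_sketch(df_demo, ["Drug", "Cell Line"])
print({name: len(part) for name, part in split.items()})
# {'train': 49, 'valid': 1, 'test': 4}
```

Because the fractions are applied to the entities of each column, the share of rows that end up in the test fold is roughly the product of the per-column fractions, and the rows that mix folds (46 of 100 in the toy example) are left unused.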