DeepGraphLearning / torchdrug

A powerful and flexible machine learning platform for drug discovery
https://torchdrug.ai/
Apache License 2.0
1.42k stars 199 forks

example for tasks.MultipleBinaryClassification #233

Open GZ82 opened 10 months ago

GZ82 commented 10 months ago

Hi, can anyone provide an example for tasks.MultipleBinaryClassification? The setting for its task argument is very different from that of tasks.PropertyPrediction. The documentation says it is optional, or a list of int. I tried both leaving it out and passing a list like [1,0,...]:

task = tasks.MultipleBinaryClassification(
    model, task=list(data_raw.one_or_zero), 
    criterion="bce", metric=('auprc@micro', 'f1_max'), 
    num_mlp_layer=1, normalization=False
)

My dataset.tasks is ["one_or_zero"]. When I run:

solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                        gpus=[0], batch_size=1024)

I get the following error:

12:39:33   Preprocess training set
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
/home/ec2-user/projects/ts_2023/getrightmol/src/getrightmol/models/torchdrug.ipynb Cell 18 line 4
      2 counter = 1
      3 for i in range(counter):
----> 4     solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
      5                         gpus=[0], batch_size=1024) #
      6     solver.train(num_epoch=10)
      7     results.append(solver.evaluate("valid")) # evaluate based on: "train", "valid", "test"

File /opt/conda/envs/torchdrug/lib/python3.10/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File /opt/conda/envs/torchdrug/lib/python3.10/site-packages/torchdrug/core/core.py:296, in _Configurable.__new__.<locals>.wrapper(init, self, *args, **kwargs)
    294     config.pop(k)
    295 self._config = dict(config)
--> 296 return init(self, *args, **kwargs)

File /opt/conda/envs/torchdrug/lib/python3.10/site-packages/torchdrug/core/engine.py:92, in Engine.__init__(self, task, train_set, valid_set, test_set, optimizer, scheduler, gpus, batch_size, gradient_interval, num_worker, logger, log_interval)
     89 # TODO: more elegant implementation
     90 # handle dynamic parameters in optimizer
     91 old_params = list(task.parameters())
---> 92 result = task.preprocess(train_set, valid_set, test_set)
...
--> 270     values.append(data["targets"][self.task_indices])
    271 values = torch.stack(values, dim=0)    
    273 if self.reweight:

KeyError: 'targets'

Oxer11 commented 10 months ago

Hi, tasks.MultipleBinaryClassification is designed for protein function classification tasks. You can refer to the GearNet repo and this tutorial.
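
The traceback above fails at data["targets"][self.task_indices], which suggests the task expects each dataset item to carry all of its labels under a single "targets" key. The plain-Python sketch below illustrates that mismatch. It is not TorchDrug code: the two item layouts and the helper collect_targets are hypothetical, and the assumption that a molecule dataset keys each label by its task name is inferred only from dataset.tasks being ["one_or_zero"].

```python
# Hypothetical item layouts, for illustration only (not TorchDrug code).
# Assumption: a molecule dataset stores each label under its task name.
molecule_item = {"graph": "<molecule>", "one_or_zero": 1}
# Assumption: a protein function dataset stores all labels under "targets".
protein_item = {"graph": "<protein>", "targets": [0, 1, 1, 0]}


def collect_targets(item, task_indices):
    """Mimic the failing line: index item["targets"] by task position."""
    return [item["targets"][i] for i in task_indices]


# Works when the item has a "targets" entry.
print(collect_targets(protein_item, [0, 2]))

try:
    collect_targets(molecule_item, [0])
except KeyError as err:
    # Same failure mode as the traceback: the item has no "targets" key.
    print("KeyError:", err)
```

If this reading is right, the error comes from the shape of the dataset items rather than from the value passed as task.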

GZ82 commented 10 months ago

@Oxer11 Hi, thanks very much for your prompt reply! I want to predict the toxicity of chemicals (1 = toxic, 0 = non-toxic), i.e., binary classification, with SMILES as input. Does tasks.MultipleBinaryClassification fit this application, or should I just use tasks.PropertyPrediction and set num_class=1? Sorry, I also do not understand what task=[_ for _ in range(len(dataset.tasks))] means in the following call

task = tasks.MultipleBinaryClassification(gearnet, graph_construction_model=graph_construction_model, num_mlp_layer=3,
                                          task=[_ for _ in range(len(dataset.tasks))], criterion="bce", metric=["auprc@micro", "f1_max"])

in your tutorial
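
As far as the Python itself goes, the expression asked about only builds a list of integer positions, one per entry in dataset.tasks. A minimal sketch, using a hypothetical list of task names in place of a real dataset.tasks:

```python
# Hypothetical task names standing in for dataset.tasks.
dataset_tasks = ["label_a", "label_b", "label_c"]

# The expression from the tutorial: one integer index per task.
task = [_ for _ in range(len(dataset_tasks))]
print(task)  # [0, 1, 2]

# An equivalent, more conventional spelling of the same list.
assert task == list(range(len(dataset_tasks)))

# With a single task, as in the question, the list is just [0].
single = [_ for _ in range(len(["one_or_zero"]))]
print(single)  # [0]
```

Judging from the traceback, where self.task_indices is used to index data["targets"], these integers presumably select which label columns to train on. They are positions, not label values, so a list such as [1,0,...] built from the labels themselves would not mean the same thing.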