It also does not work on the GOODPCBA dataset.
Hi @czstudio, the original version of CIGA does not support sampling with continuous labels in regression tasks. Possible quick workarounds include converting the continuous labels into discrete labels, or changing the sampling criterion to use label distance instead of an exact match.
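To make the two workarounds concrete, here is a minimal sketch in plain Python. It is not CIGA's or GOOD's code: `bin_labels`, `distance_pairs`, `num_bins`, and `threshold` are hypothetical names chosen for illustration, and equal-width binning is just one assumed way to discretize.

```python
def bin_labels(values, num_bins=5):
    """Workaround 1: map continuous labels to integer ids with equal-width bins."""
    lo, hi = min(values), max(values)
    if hi == lo:
        return [0] * len(values)
    width = (hi - lo) / num_bins
    # Clamp so the maximum value lands in the last bin instead of bin `num_bins`.
    return [min(int((v - lo) / width), num_bins - 1) for v in values]


def distance_pairs(values, threshold):
    """Workaround 2: choose positives by label distance instead of exact match.

    For each anchor i, positives are the other samples whose label is within
    `threshold` of the anchor's label; all remaining samples are negatives.
    """
    positives, negatives = [], []
    for i, vi in enumerate(values):
        pos_i, neg_i = [], []
        for j, vj in enumerate(values):
            if j == i:
                continue
            (pos_i if abs(vi - vj) <= threshold else neg_i).append(j)
        positives.append(pos_i)
        negatives.append(neg_i)
    return positives, negatives


labels = [0.10, 0.15, 0.90, 2.00, 2.05]  # made-up regression targets
bins = bin_labels(labels, num_bins=4)
pos, neg = distance_pairs(labels, threshold=0.1)
print(bins)  # [0, 0, 1, 3, 3]
print(pos)   # [[1], [0], [], [4], [3]]
```

With binning, the existing exact-match sampling can be reused on the bin ids. With the distance criterion, an anchor may end up with no positives (the third sample above), so the loss would need to skip such anchors.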
$ goodtg --config_path final_configs/GOODPCBA/scaffold/covariate/CIGAv2.yaml
This logger will substitute general print function
INFO: -----------------------------------
Task: train
Thu Oct 5 23:19:49 2023
INFO: Load Dataset GOODPCBA
DEBUG: 10/05/2023 11:19:51 PM : Dataset: {'train': GOODPCBA(262764), 'id_val': GOODPCBA(43792), 'id_test': GOODPCBA(43792), 'val': GOODPCBA(44019), 'test': GOODPCBA(43562), 'task': 'Binary classification', 'metric': 'Average Precision'}
DEBUG: 10/05/2023 11:19:51 PM : Data(x=[21, 9], edge_index=[2, 46], edge_attr=[46, 3], y=[1, 128], smiles='CC1CCN(C(=O)CN2CC(C)Sc3ccccc32)CC1', idx=[1], scaffold='O=C(CN1CCSc2ccccc21)N1CCCCC1', domain_id=[1], env_id=[1])
INFO: Loading model...
DEBUG: 10/05/2023 11:19:51 PM : Config model
DEBUG: 10/05/2023 11:19:53 PM : Load training utils
INFO: Epoch 0: 0%|░░░░░░░░░░░░░░░░░░░░| 0/8212 [00:00<?, ?it/s]
tensor(2550, device='cuda:0') torch.Size([32, 300]) torch.Size([32, 128]) torch.Size([32, 128])
0%|░░░░░░░░░░░░░░░░░░░░| 0/8212 [00:01<?, ?it/s]
ERROR: 10/05/2023 11:19:54 PM - utils.py - line 87 : Traceback (most recent call last):
File "/home/cz/miniconda3/envs/py38/bin/goodtg", line 33, in <module>
sys.exit(load_entry_point('graph-ood', 'console_scripts', 'goodtg')())
File "/home/cz/code/GOOD-GOODv1/GOOD/kernel/main.py", line 69, in goodtg
main()
File "/home/cz/code/GOOD-GOODv1/GOOD/kernel/main.py", line 60, in main
pipeline.load_task()
File "/home/cz/code/GOOD-GOODv1/GOOD/kernel/pipelines/basic_pipeline.py", line 231, in load_task
self.train()
File "/home/cz/code/GOOD-GOODv1/GOOD/kernel/pipelines/basic_pipeline.py", line 113, in train
train_stat = self.train_batch(data, pbar) # train_stat is a dict that contains the loss
File "/home/cz/code/GOOD-GOODv1/GOOD/kernel/pipelines/basic_pipeline.py", line 74, in train_batch
loss = self.ood_algorithm.loss_calculate(raw_pred, targets, mask, node_norm, self.config)
File "/home/cz/code/GOOD-GOODv1/GOOD/ood_algorithms/algorithms/CIGA.py", line 81, in loss_calculate
assert self.rep_out.size(0)==targets[mask].size(0), print(mask.sum(),self.rep_out.size(),targets.size(),mask.size())
AssertionError: None
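Reading the values printed just before the traceback: `mask.sum()` is 2550, `self.rep_out` has shape [32, 300], and both `targets` and `mask` have shape [32, 128]. Indexing a 2-D tensor with a boolean mask returns a 1-D tensor with one element per True entry, so `targets[mask].size(0)` is 2550 (one per labelled graph/task pair), while `self.rep_out.size(0)` is 32 (one per graph). The message is `None` because `print(...)` returns `None` when used as the assert message. The sketch below reproduces both effects in plain Python; the shapes are stand-ins taken from the log, the mask pattern is arbitrary, and `check` and `check_verbose` are hypothetical helpers, not GOOD code.

```python
# Stand-ins for the logged shapes: 32 graphs in the batch, 128 tasks per graph.
num_graphs, num_tasks = 32, 128
rep_rows = num_graphs  # plays the role of self.rep_out.size(0)

# Toy mask: True where a label is present. The real run selected 2550 entries;
# this alternating pattern is arbitrary and only illustrates the mismatch.
mask = [[(g + t) % 2 == 0 for t in range(num_tasks)] for g in range(num_graphs)]
selected = sum(sum(row) for row in mask)  # plays the role of targets[mask].size(0)


def check(rep_rows, selected):
    # Same pattern as in the traceback: print() runs, then returns None,
    # so the exception is raised as "AssertionError: None".
    assert rep_rows == selected, print(rep_rows, selected)


def check_verbose(rep_rows, selected):
    # An f-string message keeps the diagnostic inside the exception itself.
    assert rep_rows == selected, f"rep rows {rep_rows} != selected labels {selected}"


try:
    check(rep_rows, selected)
    message = "no error"
except AssertionError as err:
    message = repr(err.args)

try:
    check_verbose(rep_rows, selected)
    verbose_message = "no error"
except AssertionError as err:
    verbose_message = str(err)

print(message)          # (None,)
print(verbose_message)  # rep rows 32 != selected labels 2048
```

Under this reading, the assertion can only hold when each graph contributes exactly one selected label, which the 128-task GOODPCBA targets do not satisfy. Whether CIGA is meant to handle multi-task labels is not stated in this thread.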