divelab / DIG

A library for graph deep learning research
https://diveintographs.readthedocs.io/
GNU General Public License v3.0

TypeError: train() missing 3 required positional arguments: 'data_loader', 'optimizer', and 'epochs' #171

Status: Open. yhliu2022 opened this issue 1 year ago.

yhliu2022 commented 1 year ago

While testing the SubgraphX example, I ran:

explainer = SubgraphX(grace, num_classes=4, device=device, explain_graph=False, reward_method='nc_mc_l_shapley')

and got this error:

TypeError                                 Traceback (most recent call last)
Input In [19], in <cell line: 1>()
----> 1 explainer = SubgraphX(grace, num_classes=4, device=device,
      2                        explain_graph=False, reward_method='nc_mc_l_shapley')

File ~\DIG\dig\xgraph\method\subgraphx.py:636, in SubgraphX.__init__(self, model, num_classes, device, num_hops, verbose, explain_graph, rollout, min_atoms, c_puct, expand_atoms, high2low, local_radius, sample_num, reward_method, subgraph_building_method, save_dir, filename, vis)
    629 def __init__(self, model, num_classes: int, device, num_hops: Optional[int] = None, verbose: bool = False,
    630              explain_graph: bool = True, rollout: int = 20, min_atoms: int = 5, c_puct: float = 10.0,
    631              expand_atoms=14, high2low=False, local_radius=4, sample_num=100, reward_method='mc_l_shapley',
    632              subgraph_building_method='zero_filling', save_dir: Optional[str] = None,
    633              filename: str = 'example', vis: bool = True):
    635     self.model = model
--> 636     self.model.eval()
    637     self.device = device
    638     self.model.to(self.device)

File ~\anaconda3\envs\tf\lib\site-packages\torch\nn\modules\module.py:1926, in Module.eval(self)
   1910 def eval(self: T) -> T:
   1911     r"""Sets the module in evaluation mode.
   1912 
   1913     This has any effect only on certain modules. See documentations of
   (...)
   1924         Module: self
   1925     """
-> 1926     return self.train(False)

TypeError: train() missing 3 required positional arguments: 'data_loader', 'optimizer', and 'epochs'

Oceanusity commented 1 year ago

Hello, could you provide more details about the model used here? According to the traceback, the error is raised by the call self.model.eval(), which internally calls self.train(False).
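One plausible cause, assuming the model class behind `grace` defines its own training loop as a method named `train`: torch.nn.Module.eval() delegates to self.train(False), so a subclass method called train with extra required parameters shadows Module.train(mode) and breaks eval(). The sketch below is a minimal illustration under that assumption. It uses a hypothetical pure-Python stand-in for nn.Module (reduced to the two methods the traceback shows) so it runs without PyTorch, and the class names and the `fit` rename are illustrative only. With this three-parameter signature Python reports two missing arguments, because False gets bound to data_loader; the three missing arguments in the original error suggest the real method takes one more positional parameter, which is unconfirmed.

```python
# Minimal sketch of the suspected name clash. `Module` is a hypothetical
# stand-in for torch.nn.Module, reduced to the two methods involved:
# per the traceback, Module.eval() simply returns self.train(False).
class Module:
    def __init__(self):
        self.training = True

    def train(self, mode=True):
        # Toggle training mode, as torch.nn.Module.train(mode) does.
        self.training = mode
        return self

    def eval(self):
        # Same delegation the traceback shows at module.py:1926.
        return self.train(False)


class ClashingModel(Module):
    # Hypothetical model that defines its own training loop under the name
    # `train`, shadowing Module.train(mode) with an incompatible signature.
    def train(self, data_loader, optimizer, epochs):
        for _ in range(epochs):
            pass  # training steps would go here
        return self


class RenamedModel(Module):
    # Same model with the training loop renamed to `fit` (a hypothetical
    # name), so Module.train(mode), and therefore eval(), keeps working.
    def fit(self, data_loader, optimizer, epochs):
        for _ in range(epochs):
            pass  # training steps would go here
        return self


def try_eval(model):
    """Call model.eval() and return the TypeError message, or None on success."""
    try:
        model.eval()
    except TypeError as err:
        return str(err)
    return None


if __name__ == "__main__":
    # Prints a TypeError message about missing positional arguments.
    print(try_eval(ClashingModel()))
    renamed = RenamedModel()
    print(try_eval(renamed), renamed.training)  # None False
```

If the real model does define such a method, renaming it (and updating its call sites) should let SubgraphX call eval() as intended; whether `grace` actually does this is the detail the model definition would confirm.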