median-research-group / LibMTL

A PyTorch Library for Multi-Task Learning
MIT License

AttributeError: 'Net' object has no attribute 'conv1' #72

Closed: RENEK-bool closed this issue 7 months ago

RENEK-bool commented 7 months ago

I encountered an error when running the QM9 example with the Cross_stitch architecture. The same command works when the architecture is changed to HPS. The environment is configured according to the instructions in the README.

(libmtl) [gdut@localhost qm9]$ python main.py --weighting GLS --arch Cross_stitch --dataset_path /home/gdut/Documents/ZJY/LibMTL/examples/qm9/dataset --gpu_id 0 --mode train --save_path /home/gdut/Documents/ZJY/LibMTL/examples/qm9/model

General Configuration:
    Mode: train
    Wighting: GLS
    Architecture: Cross_stitch
    Rep_Grad: False
    Multi_Input: False
    Seed: 0
    Save Path: /home/gdut/Documents/ZJY/LibMTL/examples/qm9/model
    Load Path: None
    Device: cuda:0
Optimizer Configuration:
    optim: adam
    lr: 0.0001
    weight_decay: 1e-05
Traceback (most recent call last):
  File "main.py", line 181, in <module>
    main(params)
  File "main.py", line 156, in main
    QM9model = QM9trainer(task_dict=task_dict,
  File "main.py", line 145, in __init__
    super(QM9trainer, self).__init__(task_dict=task_dict,
  File "/home/gdut/Documents/ZJY/LibMTL/LibMTL/trainer.py", line 92, in __init__
    self._prepare_model(weighting, architecture, encoder_class, decoders)
  File "/home/gdut/Documents/ZJY/LibMTL/LibMTL/trainer.py", line 104, in _prepare_model
    self.model = MTLmodel(task_name=self.task_name,
  File "/home/gdut/Documents/ZJY/LibMTL/LibMTL/trainer.py", line 101, in __init__
    super(MTLmodel, self).__init__(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)
  File "/home/gdut/Documents/ZJY/LibMTL/LibMTL/architecture/Cross_stitch.py", line 56, in __init__
    self.encoder = _transform_resnet_cross(self.encoder, task_name, device)
  File "/home/gdut/Documents/ZJY/LibMTL/LibMTL/architecture/Cross_stitch.py", line 15, in __init__
    self.resnet_conv = nn.ModuleDict({task: nn.Sequential(encoder_list[tn].conv1, encoder_list[tn].bn1,
  File "/home/gdut/Documents/ZJY/LibMTL/LibMTL/architecture/Cross_stitch.py", line 15, in <dictcomp>
    self.resnet_conv = nn.ModuleDict({task: nn.Sequential(encoder_list[tn].conv1, encoder_list[tn].bn1,
  File "/home/gdut/anaconda3/envs/libmtl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 947, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'Net' object has no attribute 'conv1'

Baijiong-Lin commented 7 months ago

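Cross-stitch only supports ResNet-based encoders, but the encoder in the QM9 example is a GNN. The Cross-stitch implementation rebuilds each task's encoder from ResNet submodules such as conv1 and bn1, which the QM9 encoder (Net) does not have, hence the AttributeError.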
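This kind of mismatch can be caught before training with a simple duck-typing check. The sketch below is not part of LibMTL. RESNET_ATTRS, missing_resnet_attrs, pick_arch, FakeResNet and the stub Net are hypothetical names. Only conv1 and bn1 are confirmed by the traceback above; the other attribute names assume a torchvision-style ResNet layout. Falling back to HPS mirrors the observation in this issue that HPS works with the QM9 encoder.

```python
# Attribute names a torchvision-style ResNet exposes. The traceback confirms
# Cross_stitch reads `conv1` and `bn1`; the rest are an assumption based on
# torchvision's ResNet layout, not taken from LibMTL's source.
RESNET_ATTRS = ("conv1", "bn1", "relu", "maxpool",
                "layer1", "layer2", "layer3", "layer4")


def missing_resnet_attrs(encoder):
    """Return the ResNet attribute names that `encoder` lacks (empty if none)."""
    # hasattr() also works on torch.nn.Module: its __getattr__ raises
    # AttributeError for unknown names, which hasattr() turns into False.
    return [name for name in RESNET_ATTRS if not hasattr(encoder, name)]


def pick_arch(encoder, preferred="Cross_stitch", fallback="HPS"):
    """Hypothetical helper: keep Cross_stitch only if the encoder looks like a ResNet."""
    if preferred == "Cross_stitch" and missing_resnet_attrs(encoder):
        return fallback
    return preferred


# Hypothetical stand-ins so the sketch runs without torch installed.
class FakeResNet:
    def __init__(self):
        for name in RESNET_ATTRS:
            setattr(self, name, object())


class Net:  # mimics the QM9 GNN encoder: no ResNet submodules
    pass


print(missing_resnet_attrs(Net())[:2])  # ['conv1', 'bn1']
print(pick_arch(Net()))                 # HPS
print(pick_arch(FakeResNet()))          # Cross_stitch
```

In practice the check would be run on an encoder instance (for example, one built from the encoder_class passed to the trainer) before choosing the --arch value, so the problem surfaces as a clear message instead of an AttributeError deep inside the model constructor.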

https://github.com/median-research-group/LibMTL/blob/b1ff34d1bc72a208ef4f42301e6021db42913653/LibMTL/architecture/Cross_stitch.py#L46

RENEK-bool commented 7 months ago

Okay, thank you.