yan1617262965 opened 2 years ago
Hello. Have you solved it? I want to know how you solved this problem.
Not yet. Have you solved it?
File "DEFT/src/lib/trainer.py", line 149:

```python
class ModleWithLoss(torch.nn.Module):
    def __init__(self, model, loss):
        super(ModleWithLoss, self).__init__()
        self.model = model
        self.loss = loss
        # Learnable weights for the detection and matching losses
        self.s_det = nn.Parameter(torch.ones(1))
        self.s_id = nn.Parameter(torch.ones(1))
```
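For context, `forward` (trainer.py line 192 in the traceback) combines the detection loss and the matching loss using these two learnable scalars: `torch.exp(-self.s_det) * loss_stats["tot"] + torch.exp(-self.s_id) * loss_matching + (self.s_det + self.s_id)`. A minimal plain-Python sketch of that arithmetic follows, assuming scalar loss values; the function name `weighted_total_loss` is hypothetical and not part of DEFT.

```python
import math


def weighted_total_loss(loss_det, loss_matching, s_det=1.0, s_id=1.0):
    """Hypothetical helper mirroring the expression on trainer.py line 192.

    s_det and s_id default to 1.0, the initial value produced by
    nn.Parameter(torch.ones(1)) in ModleWithLoss.__init__.
    """
    det_term = math.exp(-s_det) * loss_det          # detection loss, scaled down as s_det grows
    match_term = math.exp(-s_id) * loss_matching    # matching loss, scaled down as s_id grows
    regularizer = s_det + s_id                      # penalizes large s_det / s_id
    return det_term + match_term + regularizer


if __name__ == "__main__":
    # With s_det = s_id = 0 both weights are exp(0) = 1, so the result is a plain sum.
    print(weighted_total_loss(2.0, 3.0, s_det=0.0, s_id=0.0))  # 5.0
    # At the initial values (1.0) both losses are scaled by exp(-1) and 2.0 is added.
    print(weighted_total_loss(2.0, 3.0))
```

The additive `s_det + s_id` term is what stops optimization from shrinking both loss weights toward zero simply by increasing `s_det` and `s_id`.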
I am very interested in this work and would like to reproduce it. Could you help me? I can run `test.py`, but when I run `train.py` to train, it reports the following error:

```
yesyes
/opt/conda/conda-bld/pytorch_1565272279342/work/aten/src/ATen/native/IndexingUtils.h:20: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead.
/opt/conda/conda-bld/pytorch_1565272279342/work/aten/src/ATen/native/IndexingUtils.h:20: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead.
/opt/conda/conda-bld/pytorch_1565272279342/work/aten/src/ATen/native/IndexingUtils.h:20: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead.
/opt/conda/conda-bld/pytorch_1565272279342/work/aten/src/ATen/native/IndexingUtils.h:20: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead.
Traceback (most recent call last):
  File "train.py", line 133, in <module>
    main(opt)
  File "train.py", line 92, in main
    log_dict_train, _ = trainer.train(epoch, train_loader)
  File "/home/buu/Yan/DEFT/src/lib/trainer.py", line 486, in train
    return self.run_epoch("train", epoch, data_loader)
  File "/home/buu/Yan/DEFT/src/lib/trainer.py", line 246, in run_epoch
    output, loss, loss_stats = model_with_loss(batch)
  File "/home/buu/anaconda3/envs/DEFT/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/buu/Yan/DEFT/src/lib/trainer.py", line 192, in forward
    loss_stats["tot"] = torch.exp(-self.s_det) * loss_stats["tot"] + torch.exp(-self.s_id) * loss_matching + (self.s_det + self.s_id)
  File "/home/buu/anaconda3/envs/DEFT/lib/python3.6/site-packages/torch/nn/modules/module.py", line 591, in __getattr__
    type(self).__name__, name))
AttributeError: 'ModleWithLoss' object has no attribute 's_det'
```
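The last frame is `torch.nn.Module.__getattr__`, which Python calls only after ordinary attribute lookup fails; it then searches the module's registered parameters, buffers and submodules and raises this `AttributeError` if the name is in none of them. The plain-Python toy below imitates only the parameter part of that lookup. `MiniModule`, `register_weights` and `attribute_error_for` are hypothetical names, not PyTorch or DEFT code. It shows that the message means the `ModleWithLoss` instance being called did not have `s_det` registered at call time; it does not explain why that happened in this run.

```python
class MiniModule:
    """Toy imitation of how torch.nn.Module resolves registered parameters.

    Not the real torch class: it only reproduces the lookup order that
    produces the AttributeError at the bottom of the traceback.
    """

    def __init__(self):
        # nn.Module keeps parameters in a dict rather than in the instance __dict__.
        object.__setattr__(self, "_parameters", {})

    def register_parameter(self, name, value):
        self._parameters[name] = value

    def __getattr__(self, name):
        # Python calls __getattr__ only after normal attribute lookup has failed.
        params = self.__dict__.get("_parameters", {})
        if name in params:
            return params[name]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, name))


class ModleWithLoss(MiniModule):
    """Toy wrapper; register_weights=False simulates s_det/s_id never being registered."""

    def __init__(self, register_weights=True):
        super().__init__()
        if register_weights:
            self.register_parameter("s_det", 1.0)
            self.register_parameter("s_id", 1.0)


def attribute_error_for(obj, name):
    """Return the AttributeError message for obj.<name>, or None if lookup succeeds."""
    try:
        getattr(obj, name)
    except AttributeError as exc:
        return str(exc)
    return None


if __name__ == "__main__":
    print(attribute_error_for(ModleWithLoss(), "s_det"))  # None
    print(attribute_error_for(ModleWithLoss(register_weights=False), "s_det"))
```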