miandai opened this issue 1 year ago.
@miandai I can't see the difference. Can you elaborate?
Sorry for my poor English.
Pay attention to the indentation. The first two `if` conditions should be sequential, not nested.
`if self.use_distributed:` should not be nested inside `if self._model.device != self.device:`.
Could you explain in more detail, in Chinese if you like? I'd really like to know whether this is a bug that we have to fix.
```python
@property
def model(self):
    if self._model.device != self.device:
        self._model = self._model.to(self.device)

        # distributed training wrapper
        if self.use_distributed:
            if self._wrapped_model is None:
                self._wrapped_model = DDP(
                    self._model, device_ids=[self.config.run_cfg.gpu]
                )
        else:
            self._wrapped_model = self._model

    return self._wrapped_model
```
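The above is the current code of lavis/runners/runner_base.py on the master branch. If execution does not enter `if self._model.device != self.device`, the property returns `self._wrapped_model` directly, which is still `None`. A later call to `self.model.train()` then fails with: AttributeError: 'NoneType' object has no attribute 'train'
My guess is that, even when execution does not enter `if self._model.device != self.device`, the code should still go on to evaluate `if self.use_distributed`, assign `self._wrapped_model`, and only then return `self._wrapped_model`.
The corrected code is below. The only difference from the code above is that every line from `if self.use_distributed:` through `self._wrapped_model = self._model` is shifted left by one indentation level (one tab).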
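The failure path can be reproduced without PyTorch or LAVIS. This is a minimal sketch, assuming hypothetical stand-ins: `StubModel` replaces the real model, `StubDDP` replaces `DistributedDataParallel`, and `NestedRunner` replaces the runner class. It mirrors only the nested `if` layout of the master-branch property and shows that when the model is already on the target device, the property returns `None`.

```python
class StubModel:
    """Hypothetical stand-in for a model that only tracks its device."""

    def __init__(self, device: str) -> None:
        self.device = device

    def to(self, device: str) -> "StubModel":
        self.device = device
        return self

    def train(self) -> "StubModel":
        return self


class StubDDP:
    """Hypothetical stand-in for DistributedDataParallel; it just wraps the module."""

    def __init__(self, module: StubModel) -> None:
        self.module = module


class NestedRunner:
    """Hypothetical runner mirroring the nested `if` layout of the master-branch property."""

    def __init__(self, model: StubModel, device: str, use_distributed: bool) -> None:
        self._model = model
        self._wrapped_model = None
        self.device = device
        self.use_distributed = use_distributed

    @property
    def model(self):
        if self._model.device != self.device:
            self._model = self._model.to(self.device)
            # The wrapper is assigned only inside the device-mismatch branch.
            if self.use_distributed:
                if self._wrapped_model is None:
                    self._wrapped_model = StubDDP(self._model)
            else:
                self._wrapped_model = self._model
        return self._wrapped_model


# Model already on the target device: the branch is skipped and None comes back.
runner = NestedRunner(StubModel("cuda:0"), "cuda:0", use_distributed=False)
print(runner.model)  # None
try:
    runner.model.train()
except AttributeError as exc:
    print(exc)  # 'NoneType' object has no attribute 'train'
```

In this sketch, if the model starts on a different device (for example `StubModel("cpu")`), the first access takes the branch and assigns the wrapper, and later accesses return the cached value. The failure therefore only appears when the model is already on the target device the first time `model` is read.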
```python
@property
def model(self):
    # move model to device
    if self._model.device != self.device:
        self._model = self._model.to(self.device)

    # distributed training wrapper
    if self.use_distributed:
        if self._wrapped_model is None:
            self._wrapped_model = DDP(
                self._model, device_ids=[self.config.run_cfg.gpu]
            )
    else:
        self._wrapped_model = self._model

    return self._wrapped_model
```
I'm not familiar with the project, so I'm not sure whether my guess is correct.
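The sequential layout can be checked with a similar torch-free sketch. Again, `StubModel`, `StubDDP`, and `SequentialRunner` are hypothetical stand-ins, not LAVIS or PyTorch classes. The point is that with the two `if` blocks at the same level, the property returns a usable object whether or not the model had to be moved, and the distributed wrapper is still created only once.

```python
class StubModel:
    """Hypothetical stand-in for a model that only tracks its device."""

    def __init__(self, device: str) -> None:
        self.device = device

    def to(self, device: str) -> "StubModel":
        self.device = device
        return self

    def train(self) -> "StubModel":
        return self


class StubDDP:
    """Hypothetical stand-in for DistributedDataParallel; delegates train() to the module."""

    def __init__(self, module: StubModel) -> None:
        self.module = module

    def train(self) -> "StubDDP":
        self.module.train()
        return self


class SequentialRunner:
    """Hypothetical runner mirroring the corrected, sequential `if` layout."""

    def __init__(self, model: StubModel, device: str, use_distributed: bool) -> None:
        self._model = model
        self._wrapped_model = None
        self.device = device
        self.use_distributed = use_distributed

    @property
    def model(self):
        # Move the model only when it is on the wrong device.
        if self._model.device != self.device:
            self._model = self._model.to(self.device)
        # Runs on every access, so _wrapped_model is always set before returning.
        if self.use_distributed:
            if self._wrapped_model is None:
                self._wrapped_model = StubDDP(self._model)
        else:
            self._wrapped_model = self._model
        return self._wrapped_model


for start in ("cpu", "cuda:0"):
    for distributed in (False, True):
        runner = SequentialRunner(StubModel(start), "cuda:0", distributed)
        runner.model.train()  # no AttributeError in any combination
        print(start, distributed, type(runner.model).__name__)
```

This prints `StubModel` for the two non-distributed cases and `StubDDP` for the two distributed cases. Because of the `self._wrapped_model is None` guard, repeated accesses in the distributed case return the same wrapper object rather than wrapping again.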
It works for me. Thanks.
Some code in lavis/runners/runner_base.py may be wrong:
It should be:
Otherwise, this error is raised:
```
  File "/home/Hamburg/github/LAVIS/train.py", line 103, in <module>
    main()
  File "/home/Hamburg/github/LAVIS/train.py", line 99, in main
    runner.train()
  File "/home/Hamburg/github/LAVIS/lavis/runners/runner_base.py", line 369, in train
    train_stats = self.train_epoch(cur_epoch)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/Hamburg/github/LAVIS/lavis/runners/runner_base.py", line 426, in train_epoch
    self.model.train()
    ^^^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'train'
```