Open sctm002 opened 2 years ago
说明你将输入送进模型时,可能按照字典的形式,常见于transformers中,如:
inputs = {
"input_ids": ...,
"attention_mask": ...,
"token_type_ids": ...,
}
outputs = model(**inputs)
再看下源码的处理:
def scatter(self, inputs, kwargs, device_ids):
# 从inputs第一个输入中获取bsz
bsz = inputs[0].size(self.dim)
num_dev = len(self.device_ids)
所以当你上面输入过来的时候,Inputs就是个空的元组,肯定不work,可以将scatter获取bsz的代码改成我这个:
def scatter(self, inputs, kwargs, device_ids):
if len(inputs) > 0:
bsz = inputs[0].size(self.dim)
elif kwargs:
bsz = list(kwargs.values())[0].size(self.dim)
else:
raise ValueError("You must pass inputs to the model!")
num_dev = len(self.device_ids)
...
bsz = inputs[0].size(self.dim) IndexError: tuple index out of range 原版是这样写的: model = DataParallel(model, device_ids=[int(i) for i in args.device.split(',')]) 按这个版本的介绍这样写: model = BalancedDataParallel(1,model, dim=0).cuda() 就一直报错。 这个的说明内容也太少了吧。 不知道从何排错。