Link-Li / Balanced-DataParallel

这里是改进了pytorch的DataParallel, 用来平衡第一个GPU的显存使用量
231 stars 51 forks source link

用原版的就可以,用这个版本就一直报错 IndexError: tuple index out of range #15

Open sctm002 opened 2 years ago

sctm002 commented 2 years ago

bsz = inputs[0].size(self.dim) IndexError: tuple index out of range 原版是这样写的: model = DataParallel(model, device_ids=[int(i) for i in args.device.split(',')]) 按这个版本的介绍这样写: model = BalancedDataParallel(1,model, dim=0).cuda() 就一直报错。 这个的说明内容也太少了吧。 不知道从何排错。

sherlcok314159 commented 1 year ago

说明你将输入送进模型时,可能按照字典的形式,常见于transformers中,如:

inputs = {
    "input_ids": ...,
    "attention_mask": ...,
    "token_type_ids": ...,
}
outputs = model(**inputs)

再看下源码的处理:

def scatter(self, inputs, kwargs, device_ids):
    # 从inputs第一个输入中获取bsz
    bsz = inputs[0].size(self.dim)
    num_dev = len(self.device_ids)

所以当你上面输入过来的时候,Inputs就是个空的元组,肯定不work,可以将scatter获取bsz的代码改成我这个:

def scatter(self, inputs, kwargs, device_ids):
    if len(inputs) > 0:
        bsz = inputs[0].size(self.dim)
    elif kwargs:
        bsz = list(kwargs.values())[0].size(self.dim)
    else:
        raise ValueError("You must pass inputs to the model!")
    num_dev = len(self.device_ids)
    ...