CaptainEven / MCMOT

Real time one-stage multi-class & multi-object tracking based on anchor-free detection and ReID
MIT License
383 stars 82 forks source link

where() missing 2 required positional argument: "input", "other" #53

Closed Chrispaoge closed 3 years ago

Chrispaoge commented 3 years ago

您好,在跑train.py的时候遇到了这个错误 : File "D:\code\pytorchCode\MCMOT\src\lib\trains\mot.py", line 198, in forward inds = torch.where(cls_id_map == cls_id), TypeError: where() missing 2 required positional argument: "input", "other" 在计算ReID loss这一块,inds = torch.where(cls_id_map == cls_id),我应该如何来改呢?来此where函数里,我应该填入哪两个参数呢?谢谢

CaptainEven commented 3 years ago

@Chrispaoge 首先,你要确保你的opts.py参数设置正确(cls_ids, cls2id, id2cls等),然后关注训练的log, 比如hm的维数等。祝你好运!

Chrispaoge commented 3 years ago

@CaptainEven 您好,我觉得应该是我没表达清楚,您的代码在计算ReID损失时,用到了torch.where()函数,在这里您只填入了“ cls_id_map == cls_id“这一判断条件,即 image

torch.where()还缺少填入输入张量和输出张量,阅读代码后我还是不确定应该填入哪两个。不知道是不是我pytorch版本的问题,我用的pytorch1.1,已跑通您的FairMOTVehicle,但是我看了下pytorch1.1和高版本的pytorch中torch.where()函数并没有变化 cls_ids等参数都已设置好,但是torch.where()这个问题不解决,训练跑不起来

Traceback (most recent call last): File "D:/code/pytorchCode/MCMOT/src/train.py", line 162, in run(opt) File "D:/code/pytorchCode/MCMOT/src/train.py", line 121, in run log_dicttrain, = trainer.train(epoch, train_loader) File "D:\code\pytorchCode\MCMOT\src\lib\trains\base_trainer.py", line 162, in train return self.run_epoch('train', epoch, data_loader) File "D:\code\pytorchCode\MCMOT\src\lib\trains\base_trainer.py", line 97, in run_epoch output, loss, loss_stats = model_with_loss.forward(batch) File "D:\code\pytorchCode\MCMOT\src\lib\trains\base_trainer.py", line 24, in forward loss, loss_stats = self.loss.forward(outputs=outputs, batch=batch) File "D:\code\pytorchCode\MCMOT\src\lib\trains\mot.py", line 198, in forward inds = torch.where(cls_id_map == cls_id) TypeError: where() missing 2 required positional argument: "input", "other"

CaptainEven commented 3 years ago

@Chrispaoge 据我所知,torch.where函数跟numpy.where函数几乎一致, image image

并没有要求一定填入输入张量和输出张量,我用的是第二种方法,通过where获取索引,因此不需要填入x, y张量

Ronales commented 2 years ago

this problem may be caused by pytorch version, you can modify follow this:

inds = torch.where(cls_id_map == cls_id)

inds = (cls_id_map[:] == cls_id).nonzero().t()