JackEasson / SLPNet_pytorch

SLPNet: Towards End-to-End Car License Plates Detection and Recognition Using Lightweight CNN

Multi-GPU training #14

Open Gavin-zsr opened 2 years ago

Gavin-zsr commented 2 years ago

What changes does the code need so that the model can be trained on multiple GPUs? I added the following directly to train.py:

    device = "cuda:0"
    if args.cuda:
        model = model.cuda(device)
        model = torch.nn.DataParallel(model, device_ids=[0,1,2])

But an error is raised during model validation:

    Traceback (most recent call last):
      File "train.py", line 538, in <module>
        main(parser.parse_args())
      File "train.py", line 512, in main
        model = train(args, model, device)  # Train decoder
      File "train.py", line 317, in train
        obj_num_list, scores_tensor, coordinates_tensor = model(images, mode1='det_only', mode2='eval')
      File "/home/guest/fhc/anaconda3/envs/slpnet/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/guest/fhc/anaconda3/envs/slpnet/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 153, in forward
        return self.gather(outputs, self.output_device)
      File "/home/guest/fhc/anaconda3/envs/slpnet/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 165, in gather
        return gather(outputs, output_device, dim=self.dim)
      File "/home/guest/fhc/anaconda3/envs/slpnet/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 68, in gather
        res = gather_map(outputs)
      File "/home/guest/fhc/anaconda3/envs/slpnet/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 63, in gather_map
        return type(out)(map(gather_map, zip(*outputs)))
      File "/home/guest/fhc/anaconda3/envs/slpnet/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 63, in gather_map
        return type(out)(map(gather_map, zip(*outputs)))
      File "/home/guest/fhc/anaconda3/envs/slpnet/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 63, in gather_map
        return type(out)(map(gather_map, zip(*outputs)))
    TypeError: zip argument #1 must support iteration

Without multi-GPU training, it feels like this model would take a very long time to train.

JackEasson commented 2 years ago

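The cause is the post-processing in the validation stage. The call obj_num_list, scores_tensor, coordinates_tensor = model(images, mode1='det_only', mode2='eval') moves some tensors from GPU to CPU, and this runs into problems in a multi-GPU setup. You can either remove the validation step from training and validate separately at the end, or run the validation stage on a single GPU.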
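One possible way to do the single-GPU validation is sketched below. The traceback suggests that the failure happens in DataParallel's gather step, which tries to merge the per-GPU outputs and appears to hit values it cannot iterate over. torch.nn.DataParallel keeps the original network in its .module attribute, so calling that directly during validation skips the scatter/gather path. The helper names unwrap_for_eval and run_validation_step are hypothetical and not part of this repo, and the sketch assumes the images are already on the first device in device_ids (cuda:0 in the snippet above).

```python
def unwrap_for_eval(model):
    """Return the underlying network if `model` is a DataParallel wrapper.

    torch.nn.DataParallel stores the wrapped network in `.module`.
    Assumption: the unwrapped model has no attribute of its own named `module`.
    """
    return getattr(model, "module", model)


def run_validation_step(model, images):
    """Run one validation forward pass on a single GPU.

    Hypothetical helper: it calls the unwrapped model, so DataParallel's
    gather step is never invoked. The keyword arguments mirror the call
    on line 317 of train.py shown in the traceback.
    """
    eval_model = unwrap_for_eval(model)
    return eval_model(images, mode1='det_only', mode2='eval')
```

In the training loop, the validation call would then become obj_num_list, scores_tensor, coordinates_tensor = run_validation_step(model, images), while the training forward pass keeps using the DataParallel-wrapped model. Validation runs on one GPU only, so it will not get faster from the extra cards.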