将 batch_size 设置为大于 1 时报错(此处 batch_size 为 4),报错信息如下:
File "train.py", line 326, in train
model.optimize_parameters()
File "/public/home/hatmore/project/CWSAM-master/models/sam.py", line 289, in optimize_parameters
self.backward_G() # calculate graidents for G
self.backward_G() # calculate graidents for G
File "/public/home/hatmore/project/CWSAM-master/models/sam.py", line 280, in backward_G
File "/public/home/hatmore/project/CWSAM-master/models/sam.py", line 280, in backward_G
self.loss_G = self.criterionBCE(self.pred_mask, self.gt_mask) #(1,4,1024,1024) self.loss_G = self.criterionBCE(self.pred_mask, self.gt_mask) #(1,4,1024,1024)
File "/public/home/hatmore/miniconda3/envs/CWSAM/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
File "/public/home/hatmore/miniconda3/envs/CWSAM/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
return self._call_impl(*args, **kwargs)
File "/public/home/hatmore/miniconda3/envs/CWSAM/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
File "/public/home/hatmore/miniconda3/envs/CWSAM/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/public/home/hatmore/miniconda3/envs/CWSAM/lib/python3.8/site-packages/torch/nn/modules/loss.py", line 1179, in forward
return forward_call(*args, **kwargs)
File "/public/home/hatmore/miniconda3/envs/CWSAM/lib/python3.8/site-packages/torch/nn/modules/loss.py", line 1179, in forward
return F.cross_entropy(input, target, weight=self.weight,
return F.cross_entropy(input, target, weight=self.weight,
File "/public/home/hatmore/miniconda3/envs/CWSAM/lib/python3.8/site-packages/torch/nn/functional.py", line 3059, in cross_entropy
File "/public/home/hatmore/miniconda3/envs/CWSAM/lib/python3.8/site-packages/torch/nn/functional.py", line 3059, in cross_entropy
return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
ValueError: Expected input batch_size (1) to match target batch_size (4).
ValueError: Expected input batch_size (1) to match target batch_size (4).
batch_size设置大于1时报错,如下,batch_size为4 File "train.py", line 326, in train model.optimize_parameters() File "/public/home/hatmore/project/CWSAM-master/models/sam.py", line 289, in optimize_parameters self.backward_G() # calculate graidents for G self.backward_G() # calculate graidents for G File "/public/home/hatmore/project/CWSAM-master/models/sam.py", line 280, in backward_G File "/public/home/hatmore/project/CWSAM-master/models/sam.py", line 280, in backward_G self.loss_G = self.criterionBCE(self.pred_mask, self.gt_mask) #(1,4,1024,1024) self.loss_G = self.criterionBCE(self.pred_mask, self.gt_mask) #(1,4,1024,1024)
File "/public/home/hatmore/miniconda3/envs/CWSAM/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl File "/public/home/hatmore/miniconda3/envs/CWSAM/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, **kwargs) return self._call_impl(*args, **kwargs) File "/public/home/hatmore/miniconda3/envs/CWSAM/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl File "/public/home/hatmore/miniconda3/envs/CWSAM/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, **kwargs) File "/public/home/hatmore/miniconda3/envs/CWSAM/lib/python3.8/site-packages/torch/nn/modules/loss.py", line 1179, in forward return forward_call(*args, **kwargs) File "/public/home/hatmore/miniconda3/envs/CWSAM/lib/python3.8/site-packages/torch/nn/modules/loss.py", line 1179, in forward return F.cross_entropy(input, target, weight=self.weight, return F.cross_entropy(input, target, weight=self.weight, File "/public/home/hatmore/miniconda3/envs/CWSAM/lib/python3.8/site-packages/torch/nn/functional.py", line 3059, in cross_entropy File "/public/home/hatmore/miniconda3/envs/CWSAM/lib/python3.8/site-packages/torch/nn/functional.py", line 3059, in cross_entropy return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing) return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing) ValueError: Expected input batch_size (1) to match target batch_size (4). ValueError: Expected input batch_size (1) to match target batch_size (4).