xypu98 / CWSAM


batch_size #12

Open · hatmore opened this issue 3 months ago

hatmore commented 3 months ago

Setting batch_size to a value greater than 1 raises the following error (here batch_size is 4):

  File "train.py", line 326, in train
    model.optimize_parameters()
  File "/public/home/hatmore/project/CWSAM-master/models/sam.py", line 289, in optimize_parameters
    self.backward_G()  # calculate graidents for G
  File "/public/home/hatmore/project/CWSAM-master/models/sam.py", line 280, in backward_G
    self.loss_G = self.criterionBCE(self.pred_mask, self.gt_mask)  #(1,4,1024,1024)
  File "/public/home/hatmore/miniconda3/envs/CWSAM/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/public/home/hatmore/miniconda3/envs/CWSAM/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/public/home/hatmore/miniconda3/envs/CWSAM/lib/python3.8/site-packages/torch/nn/modules/loss.py", line 1179, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/public/home/hatmore/miniconda3/envs/CWSAM/lib/python3.8/site-packages/torch/nn/functional.py", line 3059, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
ValueError: Expected input batch_size (1) to match target batch_size (4).
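The traceback shows that criterionBCE dispatches to F.cross_entropy, so the error can be reproduced outside CWSAM. The following is a minimal sketch under assumed shapes: 4 classes and 8x8 masks are made up for speed (the real run used 1024x1024). The only detail taken from the error message is the mismatch itself, a prediction with batch dimension 1 against a target with batch dimension 4.

```python
import torch
import torch.nn.functional as F

# Assumed toy shapes: 4 classes, 8x8 masks (not the real CWSAM sizes).
pred_mask = torch.randn(1, 4, 8, 8)         # logits whose batch dimension is 1
gt_mask = torch.randint(0, 4, (4, 8, 8))    # class-index targets for a batch of 4

message = ""
try:
    F.cross_entropy(pred_mask, gt_mask)
except ValueError as err:
    # Same check that fails in the traceback: batch sizes of input and target differ.
    message = str(err)
print(message)
```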

xypu98 commented 3 months ago

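In torch._C._nn.cross_entropy_loss, the input and target must have matching dimensions (in particular the same batch size), because the loss is computed by pairing each prediction with its own target. All of my experiments used a batch size of 1, so the code does not handle larger batch sizes. You may need to modify it, for example by adding a loop over the samples or by aligning the dimensions, so that the shapes match when the loss is computed.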
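The two workarounds suggested above (a loop, or dimension alignment) might look roughly like the sketch below. It is not the CWSAM code. It assumes nn.CrossEntropyLoss with class-index targets of shape (B, H, W), and that a prediction with batch dimension 1 can be obtained for each image. The names loss_by_loop, loss_by_alignment and predict_one are hypothetical stand-ins for however the model produces one image's logits.

```python
import torch
import torch.nn as nn


def loss_by_loop(predict_one, images, gt_masks, criterion):
    """Option 1 (loop). predict_one is a hypothetical single-image predictor.

    Each prediction keeps batch dimension 1 and is paired with the matching
    (1, H, W) target; the per-sample losses are then averaged.
    """
    losses = []
    for i in range(images.shape[0]):
        pred_i = predict_one(images[i:i + 1])                # (1, C, H, W)
        losses.append(criterion(pred_i, gt_masks[i:i + 1]))  # target (1, H, W)
    return torch.stack(losses).mean()


def loss_by_alignment(pred_list, gt_masks, criterion):
    """Option 2 (alignment): one loss call on a concatenated batch.

    Per-sample predictions are concatenated along dim 0, so the input is
    (B, C, H, W) and its batch size matches the (B, H, W) target.
    """
    pred = torch.cat(pred_list, dim=0)
    if pred.shape[0] != gt_masks.shape[0]:
        raise ValueError(f"batch mismatch: {pred.shape[0]} vs {gt_masks.shape[0]}")
    return criterion(pred, gt_masks)


# Toy usage: batch of 4, 3 classes, 8x8 masks. The "predictor" is the identity
# on precomputed logits, purely for illustration.
torch.manual_seed(0)
criterion = nn.CrossEntropyLoss()
logits = torch.randn(4, 3, 8, 8)
gt = torch.randint(0, 3, (4, 8, 8))
loop_loss = loss_by_loop(lambda x: x, logits, gt, criterion)
aligned_loss = loss_by_alignment([logits[i:i + 1] for i in range(4)], gt, criterion)
print(loop_loss.item(), aligned_loss.item())
```

With the default reduction="mean", no class weights, no ignore_index pixels and equally sized masks, both options give the same value as a single loss over the whole batch. With class weights or ignored pixels, averaging per-sample losses weights the samples differently. If the model already produces all B predictions but stacks them along another dimension (the inline comment #(1,4,1024,1024) may hint at this, but the thread does not confirm it), the alignment step would be a reshape or permute rather than a concatenation.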