ShichenLiu / CondenseNet

CondenseNet: Light weighted CNN for mobile devices
MIT License

Dropping issue with pytorch v0.4 #14

Closed. Cadene closed this issue 6 years ago

Cadene commented 6 years ago

See: https://github.com/ShichenLiu/CondenseNet/blob/3b4398ed1987f6f7c891d81a470578dcc5c5562c/layers.py#L88

Odd behavior in the PyTorch API:

self._mask[i::self.groups, d, :, :].fill_(0)

... does not fill the mask in place, so the result has to be assigned back:

self._mask[i::self.groups, d, :, :] = self._mask[i::self.groups, d, :, :].fill_(0)

https://github.com/pytorch/pytorch/issues/2599#issuecomment-326775742
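
One possible explanation, stated as an assumption rather than something established in this thread: basic indexing (plain ints and slices) returns a view of the original tensor, while advanced indexing (an index tensor) returns a copy, so an in-place `fill_` on the copy never reaches the original. The sketch below illustrates that difference on a small tensor. The names `mask` and `idx` are illustrative and are not taken from `layers.py`.

```python
import torch

mask = torch.ones(6, 6)

# Basic indexing (slice + plain int) yields a view, so fill_ writes through.
mask[:, 4].fill_(0)
basic_wrote_through = bool((mask[:, 4] == 0).all())

# Advanced indexing (a 1-D index tensor) yields a copy, so fill_ only
# changes the temporary and the original column is left untouched.
idx = torch.tensor([2])
mask[:, idx].fill_(0)
advanced_wrote_through = bool((mask[:, idx] == 0).all())

# Index assignment writes into the original with either kind of index.
mask[:, idx] = 0
assigned = bool((mask[:, idx] == 0).all())

print(basic_wrote_through, advanced_wrote_through, assigned)
```

If the index in the real code is a tensor rather than a Python int, this would be consistent with the assignment workaround above having an effect where the bare `fill_` does not.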

lvdmaaten commented 6 years ago

Are you sure this is still the case in PyTorch v0.4? The following seems to work as expected for me:

>>> mask = torch.randn(6, 6)
>>> mask[2:4, 2].fill_(1)
tensor([ 1.,  1.])
>>> mask
tensor([[ 1.2523,  0.0436,  0.2293,  0.4972,  0.3498, -0.5289],
        [ 0.2715, -0.6395, -0.4040, -0.0956,  2.0621, -0.2742],
        [-1.1218, -0.6558,  1.0000,  0.8173, -0.9228, -0.5790],
        [ 0.5054, -0.1156,  1.0000, -2.0711, -0.4015, -0.4124],
        [-0.3459, -1.3607, -1.2971,  0.8082, -0.3922, -0.9164],
        [-0.1380,  2.4906,  0.1439,  0.5443, -0.5832, -0.0053]])
>>> torch.__version__
'0.4.0'

Cadene commented 6 years ago

Sorry for not providing a way to reproduce the issue. Here is one:

>>> mask = torch.randn(6,6)
>>> d = torch.tensor(4)
>>> mask[:,d].fill_(0)
 0
 0
 0
 0
 0
 0
[torch.FloatTensor of size (6,)]

>>> mask
-0.3400  0.9750 -2.9300  0.9474  0.7951  0.3728
 0.8730  0.8990  0.7515  0.2489  0.4229 -0.1711
 1.1368 -0.0591 -0.6083  0.1187 -1.4594  1.0921
-0.1175 -0.4797  0.4929  0.0196 -1.0943  1.2663
 1.0918 -0.1249 -1.0392 -0.1859  0.2909 -2.4262
 0.4422  1.2017  0.9938 -0.2634  0.3418  0.0506
[torch.FloatTensor of size (6,6)]

>>> torch.__version__
'0.4.0a0+f8270c0'

lvdmaaten commented 6 years ago

This is so weird!

>>> mask = torch.randn(6,6)
>>> d = torch.tensor(4)
>>> mask[:,d].fill_(0)
tensor([ 0.,  0.,  0.,  0.,  0.,  0.])
>>> mask
tensor([[ 0.4902,  0.4365,  0.3559, -1.1388,  0.0000, -0.0941],
        [-1.3191, -1.6433,  0.0495,  0.8033,  0.0000, -0.0540],
        [-0.2329,  1.8308,  1.1086,  0.7165,  0.0000,  2.1508],
        [ 1.3866, -2.4296, -0.4366,  0.1136,  0.0000, -0.5286],
        [-0.2129,  0.0132,  0.2015, -1.8690,  0.0000,  1.5336],
        [ 0.0350, -0.9614, -0.6592,  0.0031,  0.0000,  0.6525]])
>>> torch.__version__
'0.4.0'

Cadene commented 6 years ago

Oh! :o
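
The two sessions above differ only in the build (`0.4.0a0+f8270c0` versus `0.4.0`), which suggests that the handling of a 0-dim tensor used as an index changed between them. That reading is an assumption, not something confirmed in this thread. A minimal sketch of a build-independent way to write the masking step is to convert the index to a plain Python int first, so that the indexing is always basic and `fill_` acts on a view. The helper `drop_channel_` is hypothetical and not part of CondenseNet.

```python
import torch

def drop_channel_(mask, i, groups, d):
    """Zero mask[i::groups, d, :, :] in place.

    `d` may be a Python int or a 0-dim integer tensor. Converting it to
    an int forces basic indexing, which returns a view of `mask`.
    """
    col = int(d)
    mask[i::groups, col, :, :].fill_(0)
    return mask

mask = torch.ones(4, 3, 1, 1)
drop_channel_(mask, i=1, groups=2, d=torch.tensor(2))
print(mask[:, 2, 0, 0])
```

Here rows 1 and 3 of channel 2 are zeroed and everything else is left at one.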