ethanwharris / STAWM

Code for the paper 'A Biologically Inspired Visual Working Memory for Deep Networks'
https://arxiv.org/abs/1901.03665

RuntimeError: #2

Hirohong21 closed this issue 5 years ago.

Hirohong21 commented 5 years ago

Hello sir! I tried to run your code, but I ran into some errors that I don't know how to solve. Here are the tracebacks:

MNIST:

  File "/home/sdc/Desktop/STAWM-master/mnist_class_28.py", line 121, in <module>
    run(8, 512)
  File "/home/sdc/Desktop/STAWM-master/mnist_class_28.py", line 115, in run
    trial.run(200)
  File "/home/sdc/Desktop/test/test/lib/python3.5/site-packages/torchbearer/trial.py", line 121, in wrapper
    res = func(self, *args, **kwargs)
  File "/home/sdc/Desktop/test/test/lib/python3.5/site-packages/torchbearer/trial.py", line 703, in run
    final_metrics = self._fit_pass(state)[torchbearer.METRICS]
  File "/home/sdc/Desktop/test/test/lib/python3.5/site-packages/torchbearer/trial.py", line 276, in wrapper
    res = func(self, *args, **kwargs)
  File "/home/sdc/Desktop/test/test/lib/python3.5/site-packages/torchbearer/trial.py", line 768, in _fit_pass
    state[torchbearer.OPTIMIZER].step(closure)
  File "/home/sdc/Desktop/test/test/lib/python3.5/site-packages/torch/optim/adam.py", line 58, in step
    loss = closure()
  File "/home/sdc/Desktop/test/test/lib/python3.5/site-packages/torchbearer/trial.py", line 752, in closure
    state[torchbearer.Y_PRED] = state[torchbearer.MODEL](state[torchbearer.X], state=state)
  File "/home/sdc/Desktop/test/test/lib/python3.5/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/sdc/Desktop/STAWM-master/mnist_class_28.py", line 83, in forward
    x = self.memory.glimpse(x, image)
  File "/home/sdc/Desktop/STAWM-master/memory.py", line 112, in glimpse
    x = self.locator(pose, image)
  File "/home/sdc/Desktop/test/test/lib/python3.5/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/sdc/Desktop/STAWM-master/modules.py", line 29, in forward
    theta = theta.view(x.size(0), 2, 4)
RuntimeError: shape '[128, 2, 4]' is invalid for input of size 768

Celeba:

  File "/home/sdc/Desktop/STAWM-master/celeba_draw_32.py", line 292, in <module>
    run(8, 32, 128, 0, device='cuda')
  File "/home/sdc/Desktop/STAWM-master/celeba_draw_32.py", line 288, in run
    trial.run(100)
  File "/home/sdc/Desktop/test/test/lib/python3.5/site-packages/torchbearer/trial.py", line 121, in wrapper
    res = func(self, *args, **kwargs)
  File "/home/sdc/Desktop/test/test/lib/python3.5/site-packages/torchbearer/trial.py", line 703, in run
    final_metrics = self._fit_pass(state)[torchbearer.METRICS]
  File "/home/sdc/Desktop/test/test/lib/python3.5/site-packages/torchbearer/trial.py", line 276, in wrapper
    res = func(self, *args, **kwargs)
  File "/home/sdc/Desktop/test/test/lib/python3.5/site-packages/torchbearer/trial.py", line 768, in _fit_pass
    state[torchbearer.OPTIMIZER].step(closure)
  File "/home/sdc/Desktop/test/test/lib/python3.5/site-packages/torch/optim/adam.py", line 58, in step
    loss = closure()
  File "/home/sdc/Desktop/test/test/lib/python3.5/site-packages/torchbearer/trial.py", line 752, in closure
    state[torchbearer.Y_PRED] = state[torchbearer.MODEL](state[torchbearer.X], state=state)
  File "/home/sdc/Desktop/test/test/lib/python3.5/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/sdc/Desktop/STAWM-master/celeba_draw_32.py", line 146, in forward
    x, inverse = self.memory.glimpse(x, image)
  File "/home/sdc/Desktop/STAWM-master/memory.py", line 112, in glimpse
    x = self.locator(pose, image)
  File "/home/sdc/Desktop/test/test/lib/python3.5/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/sdc/Desktop/STAWM-master/modules.py", line 29, in forward
    theta = theta.view(x.size(0), 2, 4)
RuntimeError: shape '[128, 2, 4]' is invalid for input of size 192

ethanwharris commented 5 years ago

Seems to be working fine for me. The line theta = theta.view(x.size(0), 2, 4) is actually theta = theta.view(x.size(0), 2, 3) in the codebase. Changing it back to that should work.
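Why (2, 3) rather than (2, 4): Tensor.view only succeeds when the requested shape holds exactly as many elements as the tensor, and a 2D affine transform has six parameters per sample (the 2 x 3 matrix that torch.nn.functional.affine_grid expects). That the STAWM locator passes theta to affine_grid is an assumption here, since modules.py is not shown in this thread. The MNIST numbers are consistent with it: 768 = 128 x 2 x 3, whereas (128, 2, 4) would need 1024 elements. The stdlib-only sketch below reproduces that element-count check; can_view is a hypothetical helper, not part of STAWM or PyTorch.

```python
from math import prod


def can_view(numel, shape):
    """Hypothetical helper: mimic the element-count rule of Tensor.view.

    A view is only possible when the product of the requested dimensions
    equals the number of elements in the tensor (-1 inference is ignored).
    """
    return prod(shape) == numel


batch_size = 128   # x.size(0), taken from the MNIST traceback
theta_numel = 768  # "input of size 768", taken from the MNIST traceback

print(can_view(theta_numel, (batch_size, 2, 4)))  # False: 128 * 2 * 4 = 1024
print(can_view(theta_numel, (batch_size, 2, 3)))  # True: 128 * 2 * 3 = 768
```

This only explains the MNIST numbers; it makes no claim about the size-192 case in the Celeba traceback.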