yifannnwu / NODEO-DIR

[CVPR 2022] NODEO: A Neural Ordinary Differential Equation Based Optimization Framework for Deformable Image Registration

Image dimension/kernel problems #2

Closed: wangtongy closed this issue 9 months ago

wangtongy commented 11 months ago

Hi,

I am trying to use the registration code to register two MR images with dimensions 720 × 722 × 400. However, I run into the following error:

Traceback (most recent call last):
  File "/NODEO-DIR/Registration.py", line 197, in <module>
    main(config)
  File "/NODEO-DIR/Registration.py", line 15, in main
    df, df_with_grid, warped_moving = registration(config, device, moving, fixed)
  File "/NODEO-DIR/Registration.py", line 62, in registration
    all_phi = ode_train(grid, Tensor(np.arange(config.time_steps)), return_whole_sequence=True)
  File "/.conda/envs/pytorch_tongyao/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/NODEO-DIR/NeuralODE.py", line 180, in forward
    z = ODEAdjoint.apply(z0, t, self.func.flatten_parameters(), self.func, self.ode_solve, self.STEP_SIZE)
  File "/NODEO-DIR/NeuralODE.py", line 86, in forward
    z0 = ode_solve(z0, torch.abs(t[i_t + 1] - t[i_t]), func, STEP_SIZE)
  File "/NODEO-DIR/NeuralODE.py", line 41, in Euler
    z = z + step_size * f(z)
  File "/.conda/envs/pytorch_tongyao/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/NODEO-DIR/Network.py", line 131, in forward
    x = self.relu(self.lin1(x))
  File "/.conda/envs/pytorch_tongyao/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/.conda/envs/pytorch_tongyao/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x32256 and 864x16)

Could you help me with this problem?

Thanks!
Cornelia

yifannnwu commented 9 months ago

Thanks for your question. Please update the first argument (864) of self.lin1 = nn.Linear(864, self.bs, bias=bias) in Network.py to fit your input size. For the images above, the traceback reports 32256 flattened features.
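The arithmetic behind the two numbers in the error message can be sketched in plain Python. This is only an illustration under assumptions the thread does not confirm: that the image is first downsampled by a factor of 2, that the network then applies five stride-2 stages which each halve every spatial dimension (rounding up), and that the last stage has 32 output channels. The function name `flattened_features`, its parameters, and the 160 × 192 × 144 reference shape are hypothetical; they were chosen because together they reproduce both 864 and 32256.

```python
def flattened_features(image_shape, initial_downsample=2,
                       num_stride2_stages=5, out_channels=32):
    """Estimate the flattened feature count that reaches the first linear layer.

    All defaults are assumptions for illustration, not values confirmed
    by the NODEO-DIR source.
    """
    features = out_channels
    for size in image_shape:
        # Assumed initial downsampling of the input volume.
        size = size // initial_downsample
        # Each assumed stride-2 stage halves the dimension, rounding up.
        for _ in range(num_stride2_stages):
            size = (size + 1) // 2
        features *= size
    return features


# Hypothetical reference shape that reproduces the hard-coded 864.
print(flattened_features((160, 192, 144)))  # 864
# The 720 x 722 x 400 images from this issue.
print(flattened_features((720, 722, 400)))  # 32256
```

If these assumptions hold for your copy of Network.py, the value returned for your image shape is the number to use as the first argument of nn.Linear. In any case, the safest source for that number is the mat1 shape printed in the RuntimeError itself.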