locuslab / torchdeq

Modern Fixed Point Systems using Pytorch
MIT License
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 5 but got size 10 for tensor number 1 in the list. #2

Closed: georgewanglz2019 closed this issue 5 months ago

georgewanglz2019 commented 6 months ago

Hi, I am trying to use your framework to build my own DEQ model with a simple fully connected neural network, but I keep getting an error that I can't fix. Can you help me resolve it? Maybe there is something wrong with how I defined the network?

Many thanks in advance!

Code:

import torch
import torch.nn as nn

from torchdeq import get_deq
from torchdeq.norm import apply_norm, reset_norm
import os

import matplotlib.pyplot as plt
import numpy as np

import torch.nn.functional as F

class FCNN(nn.Module):
    def __init__(self, input_dim, neurons_per_layer, output_dim):
        super(FCNN, self).__init__()

        self.linear1 = nn.Linear(input_dim, neurons_per_layer)
        self.linear2 = nn.Linear(neurons_per_layer,neurons_per_layer)
        self.out = nn.Linear(neurons_per_layer, output_dim)

    def forward(self, x, z):
        # z: links' flow  x: context

        zx = torch.cat((z, x))

        z_processed = F.relu(self.linear1(zx))
        z_processed = F.relu(self.linear2(z_processed))

        return F.relu(self.out(z_processed))

if __name__ == '__main__':

    seed = 1
    torch.manual_seed(seed)

    fcnn = FCNN(input_dim=15, neurons_per_layer=20, output_dim=10)
    print(fcnn)

    x_, z_ = torch.ones(5), torch.ones(10)  # z0
    for i in range(10):
        z_ = fcnn(x_, z_)
        print('z:', z_)

    # Let's try a multi-variable DEQ!
    deq = get_deq(f_solver='broyden', f_max_iter=20, f_tol=1e-6)

    x_, z_ = torch.ones(5), torch.zeros(10)  # z0
    # f = lambda z: fcnn(x, z)
    z_out, info = deq(fcnn, (x_, z_))

Errors:

Traceback (most recent call last):
  File "E:\PycharmProjects\torch_gpu\DEQ4TA\FCNN.py", line 55, in <module>
    z_out, info = deq(fcnn, (x_, z_))
  File "C:\Users\Leizhen.conda\envs\torch_gpu\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Leizhen.conda\envs\torch_gpu\lib\site-packages\torchdeq\core.py", line 592, in forward
    deq_func, z_star = deq_decorator(func, z_star, no_stat=self.no_stat)
  File "C:\Users\Leizhen.conda\envs\torch_gpu\lib\site-packages\torchdeq\utils\layer_utils.py", line 139, in deq_decorator
    return func, func.list2vec(z_init)
  File "C:\Users\Leizhen.conda\envs\torch_gpu\lib\site-packages\torchdeq\utils\layer_utils.py", line 60, in list2vec
    return torch.cat(z_list, dim=1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 5 but got size 10 for tensor number 1 in the list.

georgewanglz2019 commented 6 months ago

Modified:

x_, z_ = torch.ones(5), torch.zeros(10)  # z0
f = lambda z: fcnn(x_, z)
z_out, info = deq(f, z_)

It seems I needed to define a function f that takes only one input, z. Now it works!

Gsunshine commented 5 months ago

Hi @georgewanglz2019 ,

Thank you for your interest in TorchDEQ! I apologize for my late response. (I was busy working on the Consistency Models Made Easy blog post, which I also recommend. :)

Yes, indeed: you should pass the DEQ module a Python functor that takes only the equilibrium tensor z as input. The DEQ module will solve for the fixed point z* and track its gradient.
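
To illustrate why the solver should see only z, here is a minimal pure-Python sketch. It is not TorchDEQ's implementation: `fixed_point` and `layer` are hypothetical names, plain lists stand in for tensors, and naive iteration stands in for the Broyden solver used in the issue. The context x is captured by a closure, so the solver never treats it as part of the equilibrium state.

```python
from typing import Callable, List


def fixed_point(f: Callable[[List[float]], List[float]],
                z0: List[float],
                max_iter: int = 100,
                tol: float = 1e-10) -> List[float]:
    """Naive fixed-point iteration z <- f(z) (hypothetical helper, not TorchDEQ)."""
    z = list(z0)
    for _ in range(max_iter):
        z_next = f(z)
        # Stop once the largest per-entry change drops below the tolerance.
        if max(abs(a - b) for a, b in zip(z_next, z)) < tol:
            return z_next
        z = z_next
    return z


def layer(x: List[float], z: List[float]) -> List[float]:
    """Toy contraction taking context x and state z (assumed for illustration)."""
    shift = sum(x) / len(x)
    return [0.5 * zi + shift for zi in z]


x = [1.0] * 5    # context: fixed input, not solved for
z0 = [0.0] * 10  # initial guess for the equilibrium state

# The closure binds x, so the solver is handed a function of z alone.
f = lambda z: layer(x, z)
z_star = fixed_point(f, z0)
print(z_star[0])
```

With this toy map, each entry of the state converges toward 2, the solution of z = 0.5 * z + 1. Passing `(x, z)` together instead would ask the solver to treat x as a second equilibrium variable, which matches the `list2vec` concatenation seen in the traceback.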

If you have a joint equilibrium formulation, for example, [a, b, c] = f([a, b, c], x), you can also code it like this.
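
One way to read the joint formulation [a, b, c] = f([a, b, c], x) is sketched below in pure Python. This is an assumption-laden illustration, not TorchDEQ's API: `joint_fixed_point` and `joint_layer` are hypothetical names, scalars stand in for tensors, and naive iteration replaces the real solver. All three states are updated together, while x is still bound by a closure.

```python
from typing import Callable, Tuple

State = Tuple[float, float, float]


def joint_fixed_point(f: Callable[..., State],
                      state0: State,
                      max_iter: int = 200,
                      tol: float = 1e-12) -> State:
    """Iterate (a, b, c) <- f(a, b, c) until the joint state stops changing."""
    state = tuple(state0)
    for _ in range(max_iter):
        new_state = tuple(f(*state))
        if max(abs(n - o) for n, o in zip(new_state, state)) < tol:
            return new_state
        state = new_state
    return state


def joint_layer(a: float, b: float, c: float, x: float) -> State:
    """Toy coupled contraction (assumed for illustration only)."""
    return 0.5 * b + x, 0.5 * c, 0.5 * a


x = 1.0  # context, held fixed

# Only the equilibrium variables a, b, c are exposed to the solver.
f = lambda a, b, c: joint_layer(a, b, c, x)
a_star, b_star, c_star = joint_fixed_point(f, (0.0, 0.0, 0.0))
print(a_star, b_star, c_star)
```

For this toy system the equilibrium satisfies a = 0.5 * b + x, b = 0.5 * c, and c = 0.5 * a, so the three values can be checked by substituting them back into `joint_layer`.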

Let me know if you have further questions!

Thanks, Zhengyang