quiqi / relu_kan


ReluKAN seems to only handle single node inputs? #1

Open huangst21 opened 4 weeks ago

huangst21 commented 4 weeks ago

An error is reported when the ReluKAN input has more than one node: RuntimeError: The size of tensor a (30) must match the size of tensor b (8) at non-singleton dimension 1
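For reference, a minimal sketch that reproduces the mismatch (the import path and the g, k, and input-size values are my assumptions; g=5, k=3 gives g+k=8 phase columns):

import torch
from relu_kan import ReLUKANLayer  # hypothetical import path

layer = ReLUKANLayer(input_size=30, g=5, k=3, output_size=2)
x = torch.randn(8, 30)  # batch of 8 samples, 30 input nodes
y = layer(x)  # RuntimeError: The size of tensor a (30) must match
              # the size of tensor b (8) at non-singleton dimension 1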

Muzixingxi commented 3 weeks ago

Excuse me, has this problem been solved?

quiqi commented 3 weeks ago

Does exp_fitting.py run properly for you? That script has a corresponding multi-input example.

reeyarn commented 1 week ago

repair_relu_kan.py.zip

I have figured out a solution: the forward method of the layer class, ReLUKANLayer.forward(...), needs to be modified.

Attached is a complete example; the input has 3 dimensions and the output has 2 dimensions.

import numpy as np
import torch
import torch.nn as nn


class ReLUKANLayer(nn.Module):
    def __init__(self, input_size: int, g: int, k: int, output_size: int, train_ab: bool = True):
        super().__init__()
        self.g, self.k, self.r = g, k, 4*g*g / ((k+1)*(k+1))
        self.input_size, self.output_size = input_size, output_size
        # g + k basis functions per input, with phase boundaries spaced 1/g apart
        phase_low = np.arange(-k, g) / g
        phase_height = phase_low + (k+1) / g
        # one row of phases per input dimension; trainable when train_ab is True
        self.phase_low = nn.Parameter(torch.Tensor(np.array([phase_low for i in range(input_size)])),
                                      requires_grad=train_ab)
        self.phase_height = nn.Parameter(torch.Tensor(np.array([phase_height for i in range(input_size)])),
                                         requires_grad=train_ab)
        # the kernel spans the whole (g+k, input_size) grid, so this conv acts as
        # a dense linear combination of all basis responses per output channel
        self.equal_size_conv = nn.Conv2d(1, output_size, (g+k, input_size))

    def forward(self, x):
        # x: (batch, input_size). Expand to (batch, input_size, g+k)
        # so it matches the shape of self.phase_low
        x_expanded = x.unsqueeze(2).expand(-1, -1, self.phase_low.size(1))

        # Perform the subtractions with broadcasting over the batch dimension
        x1 = torch.relu(x_expanded - self.phase_low)
        x2 = torch.relu(self.phase_height - x_expanded)

        # Squared product of the two ReLU ramps, scaled by self.r
        x = x1 * x2 * self.r
        x = x * x
        # Lay the basis responses out as a single-channel "image" for the conv
        x = x.reshape((len(x), 1, self.g + self.k, self.input_size))
        x = self.equal_size_conv(x)
        # x = x.reshape((len(x), self.output_size, 1))  # original 3-D output
        x = x.reshape((len(x), self.output_size))
        return x
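To sanity-check the modified forward, here is a quick shape test (the g and k values are arbitrary choices of mine, not required by the fix):

layer = ReLUKANLayer(input_size=3, g=5, k=3, output_size=2)
x = torch.rand(8, 3)   # batch of 8 samples, 3 input features
y = layer(x)
print(y.shape)         # torch.Size([8, 2])

Since the conv kernel covers the full (g+k, input_size) grid, each sample yields a single scalar per output channel, so the final reshape to (batch, output_size) is safe.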