engichang1467 / KAN-Mixer

Source code of KAN-Mixer

Solved dimension errors #2

Closed engichang1467 closed 3 months ago

engichang1467 commented 3 months ago

I have finally resolved the tensor dimension error!

Problem

After the transformation, the tensor's dimensions are incompatible with the original KANLinear. The KANLinear class expects a 2D input of shape (batch, in_features), but I want it to accept a 3D tensor.

Solution

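I don't have to change anything else in KAN. I only need to modify the forward method of KANLinear so that it collapses all leading dimensions of the input into a single batch dimension (keeping the last dimension, in_features), and then reshapes the output so that its leading dimensions match the input's and its last dimension is out_features.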

import torch
import torch.nn as nn
import torch.nn.functional as F


class KANLinear(nn.Module):
    # ... (keep the existing code here)

    def forward(self, x: torch.Tensor):
        # Save the original shape so the leading dimensions can be restored
        original_shape = x.shape

        # Collapse all leading dimensions into one batch dimension:
        # (..., in_features) -> (N, in_features)
        x = x.contiguous().view(-1, self.in_features)

        # Base branch: activation followed by a linear map
        base_output = F.linear(self.base_activation(x), self.base_weight)
        # Spline branch: B-spline bases followed by a linear map
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.spline_weight.view(self.out_features, -1),
        )

        # Sum the two branches
        output = base_output + spline_output

        # Restore the leading dimensions: (N, out_features) -> (..., out_features)
        output = output.view(*original_shape[:-1], -1)

        return output
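
To make the shape bookkeeping explicit, here is a minimal sketch in plain Python (no PyTorch needed) of what the two view calls do to the shapes. The helper names `flattened_shape` and `restored_shape` and the example sizes (8, 196, 64) and 128 are made up for illustration; they are not part of the repo.

```python
from math import prod


def flattened_shape(input_shape, in_features):
    """Shape you get from flattening a tensor of `input_shape` to (-1, in_features)."""
    total = prod(input_shape)
    if total % in_features != 0:
        raise ValueError("total number of elements is not divisible by in_features")
    return (total // in_features, in_features)


def restored_shape(input_shape, out_features):
    """Shape after restoring the leading dims when each row has `out_features` values."""
    return (*input_shape[:-1], out_features)


# Hypothetical 3D input: (batch, tokens, in_features)
input_shape = (8, 196, 64)

print(flattened_shape(input_shape, in_features=64))   # (1568, 64)
print(restored_shape(input_shape, out_features=128))  # (8, 196, 128)
```

One thing to watch: flattening to (-1, in_features) only requires the total element count to be divisible by in_features. If the last dimension of the input is not actually in_features, the call can still succeed and silently mix up rows, so it is probably worth checking x.size(-1) == self.in_features before flattening.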