I have finally resolved the tensor dimension error!
Problem
The tensor I get after my transformation is 3D, which is not compatible with the original KANLinear. The KANLinear class expects a 2D input of shape (N, in_features), but I want it to accept a 3D tensor.
Solution
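I don't have to change anything else in KANLinear. I only need to modify the forward method: flatten all leading dimensions of the input into a single batch dimension (so a 3D tensor of shape (batch, seq_len, in_features) becomes (batch * seq_len, in_features)), run the layer as usual, and then reshape the output back to the original leading dimensions, with out_features as the last dimension.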
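import torch
import torch.nn as nn
import torch.nn.functional as F


class KANLinear(nn.Module):
    # ... (keep the existing code here)

    def forward(self, x: torch.Tensor):
        # The last dimension must hold the features; otherwise view() below
        # could silently regroup elements instead of raising an error.
        assert x.size(-1) == self.in_features
        # Save the original shape so the leading dimensions can be restored
        original_shape = x.shape
        # Collapse all leading dimensions into one batch dimension:
        # (..., in_features) -> (N, in_features)
        x = x.contiguous().view(-1, self.in_features)

        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.spline_weight.view(self.out_features, -1),
        )
        # Sum the base and spline branches
        output = base_output + spline_output

        # Restore the leading dimensions: (N, out_features) -> (..., out_features)
        output = output.view(*original_shape[:-1], -1)
        return output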
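To make the shape bookkeeping in forward concrete, here is a small sketch that traces the shapes without needing PyTorch. The helper names (flattened_shape, restored_shape) and the sizes (8, 16, 32 and 64) are hypothetical and chosen only for illustration. They mirror what x.view(-1, self.in_features) and output.view(*original_shape[:-1], -1) do to the shape.

```python
from math import prod


def flattened_shape(input_shape, in_features):
    """Shape after x.view(-1, in_features): all leading dims merge into one."""
    if input_shape[-1] != in_features:
        raise ValueError("last dimension must equal in_features")
    return (prod(input_shape[:-1]), in_features)


def restored_shape(input_shape, out_features):
    """Shape after output.view(*original_shape[:-1], -1)."""
    return (*input_shape[:-1], out_features)


# Hypothetical sizes: batch=8, seq_len=16, in_features=32, out_features=64
input_shape = (8, 16, 32)
flat = flattened_shape(input_shape, in_features=32)
out = restored_shape(input_shape, out_features=64)

print(flat)  # (128, 32)
print(out)   # (8, 16, 64)
```

The output keeps the leading dimensions of the input, but its last dimension is out_features, so it only matches the input shape exactly when in_features equals out_features.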