Blealtan / efficient-kan

An efficient pure-PyTorch implementation of Kolmogorov-Arnold Network (KAN).
MIT License

Allow interchangeability of KANLinear with nn.Linear #6

Closed: akaashdash closed this 1 month ago

akaashdash commented 1 month ago

Currently KANLinear strictly checks that its input has shape (batch_size, in_features). This rules out other input shapes even when the feature size is the same. nn.Linear, on the other hand, accepts any shape whose last dimension matches in_features, e.g. (*, in_features), and returns a tensor of shape (*, out_features), leaving the leading dimensions unchanged. Because of this difference, KANLinear and nn.Linear are currently not interchangeable, which prevents people from swapping MLPs for KANs in existing code to explore the differences between the two.
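A minimal sketch of the nn.Linear shape behavior described above. The sizes (8 input features, 4 output features, the batch of 32 and the leading dimensions 2×3×5) are arbitrary values chosen for illustration:

```python
import torch
import torch.nn as nn

# nn.Linear only constrains the last dimension; every leading dimension
# is treated as a batch dimension and passed through unchanged.
linear = nn.Linear(in_features=8, out_features=4)

x2d = torch.randn(32, 8)        # (batch_size, in_features)
x4d = torch.randn(2, 3, 5, 8)   # (*, in_features) with extra leading dims

print(linear(x2d).shape)  # torch.Size([32, 4])
print(linear(x4d).shape)  # torch.Size([2, 3, 5, 4])
```

A layer that insists on exactly two dimensions would accept `x2d` but reject `x4d`, which is the mismatch this PR addresses.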

This change would allow for better shape compatibility, and therefore allow people to interchange KANLinear (or KAN) with nn.Linear layers.
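One way to get this compatibility is to flatten all leading dimensions into a single batch dimension, run the 2-D-only computation, and then restore the leading dimensions. The sketch below illustrates that idea; it is not the code merged in this PR. `Strict2DLayer` is a hypothetical stand-in for a layer that only accepts (batch_size, in_features), and `LastDimAdapter` is a hypothetical wrapper name. The wrapper assumes the wrapped layer exposes `in_features` and `out_features` attributes, as nn.Linear does:

```python
import torch
import torch.nn as nn


class Strict2DLayer(nn.Module):
    """Hypothetical stand-in for a layer that only accepts (batch_size, in_features)."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.inner = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Strict check: exactly two dimensions, last one must match.
        assert x.dim() == 2 and x.size(1) == self.in_features
        return self.inner(x)


class LastDimAdapter(nn.Module):
    """Hypothetical wrapper: accepts (*, in_features), returns (*, out_features)."""

    def __init__(self, layer: nn.Module):
        super().__init__()
        self.layer = layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.size(-1) == self.layer.in_features
        leading = x.shape[:-1]  # every dimension except the feature dimension
        # Flatten all leading dimensions into one batch dimension.
        y = self.layer(x.reshape(-1, self.layer.in_features))
        # Restore the original leading dimensions on the output.
        return y.reshape(*leading, self.layer.out_features)


# Usage: the wrapped layer now accepts the same shapes as nn.Linear.
layer = LastDimAdapter(Strict2DLayer(8, 4))
x = torch.randn(2, 3, 5, 8)
print(layer(x).shape)  # torch.Size([2, 3, 5, 4])
```

Because the reshape only regroups the batch dimensions, each feature vector is processed exactly as it would be in a plain (batch_size, in_features) call, so the layer's math does not need to change.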

Blealtan commented 1 month ago

Thank you for that! I was just about to write that myself after reading the issue, before I saw your PR. And sorry for the delay. LGTM.