KindXiaoming / pykan

Kolmogorov Arnold Networks
MIT License

Continual learning with KAN #227

Open lukmanulhakeem97 opened 1 month ago

lukmanulhakeem97 commented 1 month ago

Regardless of computation, is KAN currently able to perform continual learning on 2D input data (or its flattened 1D feature vectors)? If not, what is the main challenge KAN has with it?

ASCIIJK commented 1 month ago

I have run the experiments. It seems that KAN has a hard time learning 2D input data without forgetting. Specifically, we use a mixed 2D Gaussian distribution with 5 peaks to construct a sequence of CL tasks, shown below: [figure: Ground_task5]. The model learns each peak from 50,000 data points. For example, the data points of the first task are shown here: [figure: Pred_task0]. After all 5 tasks we get: [figure: Pred_task4]. This forgetting issue occurs in every task, e.g., task 1: [figure: Pred_task1].

PS: We use the model "model = KAN(width=[2, 16, 1], grid=5, k=6, noise_scale=0.1, bias_trainable=False, sp_trainable=False, sb_trainable=False)", and we have made sure that the loss goes down to zero in each task, so you can see a perfect peak matching the current training data. We suspect that KAN may be unable to learn high-dimensional data without forgetting.
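For reference, a minimal sketch of this kind of sequential-peak experiment (peak centres, sample generation, optimizer, and step counts are placeholders, not the exact script behind the figures above), treating pykan's KAN as an ordinary nn.Module under a manual training loop:

```python
# Sketch: fit a 5-peak 2D Gaussian mixture one peak at a time and inspect forgetting.
import torch
from kan import KAN

torch.manual_seed(0)

# Hypothetical peak centres in the unit square; the width is a placeholder.
centers = torch.tensor([[0.1, 0.5], [0.3, 0.5], [0.5, 0.5], [0.7, 0.5], [0.9, 0.5]])
sigma = 0.05

def peak(x, c):
    # Isotropic Gaussian bump centred at c.
    return torch.exp(-((x - c) ** 2).sum(dim=1, keepdim=True) / (2 * sigma ** 2))

# Constructor flags copied from the comment above; argument names may vary across pykan versions.
model = KAN(width=[2, 16, 1], grid=5, k=6, noise_scale=0.1,
            bias_trainable=False, sp_trainable=False, sb_trainable=False)

for task, c in enumerate(centers):
    # Task-local data: 50,000 points drawn near the current peak only.
    x = c + 3 * sigma * torch.randn(50000, 2)
    y = peak(x, c)
    opt = torch.optim.Adam(model.parameters(), lr=1e-2)
    for step in range(200):
        opt.zero_grad()
        loss = torch.mean((model(x) - y) ** 2)
        loss.backward()
        opt.step()
    print(f"task {task}: final train loss {loss.item():.4e}")

# Evaluate on a dense grid covering all peaks to visualise forgetting.
xs = torch.linspace(0, 1, 100)
grid_pts = torch.cartesian_prod(xs, xs)
with torch.no_grad():
    pred = model(grid_pts)
```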

fangkuoyu commented 1 month ago

I have tried MNIST for continual learning, but I haven't obtained any positive results yet.

There are three stages in the process:

Stage_1: Train on MNIST digits (0,1,2,3) from the train set and test on digits (0,1,2,3) from the test set (establishing a baseline);

Stage_2: Train on digits (4,5,6) and test on digits (0,1,2,3,4,5,6) (with the hope that the model retains what it learned in Stage_1);

Stage_3: Train on digits (7,8,9) and test on digits (0,1,2,3,4,5,6,7,8,9) (with the hope that the model retains what it learned in Stage_1 and Stage_2).
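For concreteness, a minimal sketch of this class-incremental split (an assumed torchvision-based setup, not necessarily the code used in these experiments):

```python
# Sketch: split MNIST into the three class-incremental stages described above.
import torch
from torchvision import datasets, transforms

transform = transforms.ToTensor()
train_set = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_set = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

stages = [(0, 1, 2, 3), (4, 5, 6), (7, 8, 9)]

def subset_by_classes(dataset, classes):
    # Keep only samples whose label is in `classes` (MNIST exposes labels via .targets).
    mask = torch.isin(dataset.targets, torch.tensor(classes))
    return torch.utils.data.Subset(dataset, mask.nonzero(as_tuple=True)[0].tolist())

train_splits = [subset_by_classes(train_set, s) for s in stages]
# At stage k the test set covers all classes seen so far.
seen = [sum(stages[: k + 1], ()) for k in range(len(stages))]
test_splits = [subset_by_classes(test_set, s) for s in seen]
```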

I have tried two approaches:

Method_1: The original image (28x28) is resized to (7x7) and then flattened to a 49-dimensional vector. The KAN size is (49, 10, 10) with grid=3 and k=3. The training process is the same as Tutorial Example 7 (Ref).

Method_2: The original image (28x28) is mapped to a 64-dimensional vector by nn.Linear(28*28, 64) in PyTorch. The KAN size is (64, 16, 10) with grid=3 and k=3. The training process is the same as PyTorch Training (Ref).
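A sketch of the two model configurations as described (the resizing method, wrapper names, and any hyperparameters not listed above are assumptions):

```python
# Sketch of Method_1 and Method_2 (assumed details beyond the sizes given above).
import torch
import torch.nn as nn
import torch.nn.functional as F
from kan import KAN

# Method_1: downsample 28x28 -> 7x7, flatten to 49 inputs.
kan1 = KAN(width=[49, 10, 10], grid=3, k=3)

def method1_forward(images):  # images: [batch, 1, 28, 28]
    x = F.interpolate(images, size=(7, 7), mode="bilinear", align_corners=False)
    return kan1(x.flatten(start_dim=1))

# Method_2: a learned linear map 784 -> 64 in front of the KAN.
class LinearThenKAN(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(28 * 28, 64)
        self.kan = KAN(width=[64, 16, 10], grid=3, k=3)

    def forward(self, images):  # images: [batch, 1, 28, 28]
        return self.kan(self.proj(images.flatten(start_dim=1)))

model2 = LinearThenKAN()
```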

Roughly speaking, both methods achieve train accuracy > 90% in all three stages, but test accuracy degrades across stages (Stage 1 > 90%, Stage 2 ~ 40%, Stage 3 ~ 20%). I have also tried increasing the grid size up to 100, with no significant improvement in test accuracy.

I am wondering whether any width/grid/k setting under a memory/computation budget could reach better continual-learning accuracy on MNIST.

By the way, some implementations of convolutional KAN as drop-in replacement layers don't expose the settings 'bias_trainable=False, sp_trainable=False, sb_trainable=False', which limits the study of continual learning, e.g., on CIFAR-10.

ASCIIJK commented 1 month ago

Yes, I get the same results as yours. KAN seems to achieve continual learning only on some simple tasks, such as 1D data fitting and 2D scatter classification, and even there with many limitations on grid size and the number of layers. I find that it achieves continual learning on 2D scatter classification with a large grid size (at least 50?) and no intermediate hidden layers, but if you add intermediate hidden layers or use a smaller grid size, the model forgets very fast on subsequent tasks. Maybe I need to try more combinations of hyper-parameters.

rafaelcp commented 1 month ago

As I hypothesized in the efficient-kan repo, it seems KAN cannot do continual learning in more than 1 dimension if the output depends on more than 1 of them, as it cannot isolate ranges in groups of values the same way it does over single values. I did this experiment to show it: [figure]. Leftmost panel: a dataset composed of 10,000 pixels (100x100), where the output depends on X and Y jointly. Other panels: model prediction after training on each of the 5 rows, starting from the bottom one; each row is composed of 2,000 pixels (20x100). Notice how it generalizes the blob to the entire columns after each task, but erases it on the next task.

However, this is a [2,1] KAN without any hidden layers, and it turned out it couldn't learn the task even in batch mode. So I tried a [2,5,1] KAN, which learned it in batch mode to a reasonable degree. Unfortunately, no success with continual learning: [figure]

I'm using SGD with all biases turned off (Adam can mess things up in continual learning due to its running statistics and momentum). Also, I'm using FastKAN.
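For reference, a minimal sketch of this row-by-row protocol, written here with pykan's KAN and plain SGD rather than FastKAN (the target image, learning rate, and step counts are placeholders):

```python
# Sketch: a 100x100 target image trained in 5 tasks of 20 rows each,
# starting from the bottom block of rows, with SGD and frozen biases.
import torch
from kan import KAN

H, W, n_tasks = 100, 100, 5
ys, xs = torch.meshgrid(torch.linspace(0, 1, H), torch.linspace(0, 1, W), indexing="ij")
coords = torch.stack([xs, ys], dim=-1)                 # [H, W, 2] of (x, y) inputs

# Placeholder target that depends on x and y jointly (a single blob).
target = torch.exp(-((xs - 0.5) ** 2 + (ys - 0.2) ** 2) / 0.01)

# Flags mirror the frozen-bias setting discussed in this thread; names may vary by pykan version.
model = KAN(width=[2, 5, 1], grid=5, k=3,
            bias_trainable=False, sp_trainable=False, sb_trainable=False)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)     # plain SGD, no momentum

rows_per_task = H // n_tasks
for task in range(n_tasks):
    r0 = H - (task + 1) * rows_per_task                # bottom block first, moving upward
    x = coords[r0:r0 + rows_per_task].reshape(-1, 2)   # 20 * 100 = 2000 pixels
    y = target[r0:r0 + rows_per_task].reshape(-1, 1)
    for step in range(500):
        opt.zero_grad()
        loss = torch.mean((model(x) - y) ** 2)
        loss.backward()
        opt.step()

with torch.no_grad():
    full_pred = model(coords.reshape(-1, 2)).reshape(H, W)  # inspect for forgetting
```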

ASCIIJK commented 1 month ago

I have also found this issue: a KAN with hidden layers cannot achieve continual learning. I ran experiments on 2D scatter classification. The setup constructs 25 Gaussian distributions with different means, and the model (KAN(width=[2, 25], grid=50)) learns 5 of the 2D Gaussian distributions in each task. The results show that a KAN with more than 50 grid points performs well in continual learning. I reckon that a KAN with more grid points reduces the importance of each activation function, which keeps the key functions from being overwritten. But when I add just one hidden layer (KAN(width=[2, 25, 25], grid=50)), the model shows catastrophic forgetting. So what looks like continual learning is really just the robustness of the one-layer KAN, and this coincidence may be very fragile in multilayer KANs.

From another viewpoint, a KAN with more grid points has many more parameters; it is like using a large model to fit a small dataset, where many parameters are redundant, thereby preserving an easy decision boundary for the old tasks. In the end, it seems that most efficient-KAN libraries cannot achieve continual learning, even on 1D data-fitting tasks.
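For concreteness, a minimal sketch of that 25-cluster setup (cluster means, sample counts, optimizer, and step counts are placeholders; only the KAN widths and grid size follow the comment above):

```python
# Sketch: 25 Gaussian clusters on a 5x5 lattice, 5 new classes per task,
# learned by a one-layer KAN(width=[2, 25], grid=50).
import torch
import torch.nn.functional as F
from kan import KAN

torch.manual_seed(0)
means = torch.cartesian_prod(torch.linspace(0.1, 0.9, 5), torch.linspace(0.1, 0.9, 5))
std = 0.03  # placeholder cluster width

def sample_class(c, n=500):
    return means[c] + std * torch.randn(n, 2), torch.full((n,), c, dtype=torch.long)

model = KAN(width=[2, 25], grid=50, k=3,
            bias_trainable=False, sp_trainable=False, sb_trainable=False)

for task in range(5):
    classes = range(task * 5, task * 5 + 5)            # 5 new classes per task
    xs, ys = zip(*(sample_class(c) for c in classes))
    x, y = torch.cat(xs), torch.cat(ys)
    opt = torch.optim.Adam(model.parameters(), lr=1e-2)
    for step in range(300):
        opt.zero_grad()
        loss = F.cross_entropy(model(x), y)            # 25 output logits
        loss.backward()
        opt.step()

# Accuracy over all 25 classes after the final task reveals forgetting.
x_all, y_all = zip(*(sample_class(c, n=200) for c in range(25)))
x_all, y_all = torch.cat(x_all), torch.cat(y_all)
with torch.no_grad():
    acc = (model(x_all).argmax(dim=1) == y_all).float().mean()
print(f"accuracy on all classes after 5 tasks: {acc:.3f}")
```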

fangkuoyu commented 1 month ago

@ASCIIJK @rafaelcp The paper of KAN describes continual learning as follows:

KANs have local plasticity and can avoid catastrophic forgetting by leveraging the locality of splines. The idea is simple: since spline bases are local, a sample will only affect a few nearby spline coefficients, leaving far-away coefficients intact (which is desirable since faraway regions may have already stored information that we want to preserve).

I think that the above statement depends on the distribution of the input data. When modeling peaks in 1D, the peaks are sparsely distributed, so locality holds. But when modeling peaks in 2D, or MNIST, the target features are densely distributed in the input space, so modeling new data disturbs what was modeled on old data.
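The locality argument can be made concrete: a degree-k B-spline basis has compact support, so any single input activates only k+1 basis functions, and a gradient step on that sample touches only those few coefficients. A small self-contained check (a hand-written Cox-de Boor recursion over a hypothetical uniform knot grid):

```python
# Locality of B-spline bases: at any x only k+1 basis functions are nonzero,
# so a training sample perturbs only the coefficients of those few bases.
import numpy as np

def bspline_basis(x, knots, k):
    """Cox-de Boor recursion: values of all degree-k basis functions at scalar x."""
    B = np.array([1.0 if knots[i] <= x < knots[i + 1] else 0.0 for i in range(len(knots) - 1)])
    for d in range(1, k + 1):
        Bn = np.zeros(len(knots) - d - 1)
        for i in range(len(Bn)):
            left = 0.0 if knots[i + d] == knots[i] else (x - knots[i]) / (knots[i + d] - knots[i]) * B[i]
            right = 0.0 if knots[i + d + 1] == knots[i + 1] else (knots[i + d + 1] - x) / (knots[i + d + 1] - knots[i + 1]) * B[i + 1]
            Bn[i] = left + right
        B = Bn
    return B

k = 3
knots = np.linspace(-1, 1, 20)          # hypothetical uniform grid
vals = bspline_basis(0.13, knots, k)
print(np.nonzero(vals)[0])              # at most k+1 = 4 consecutive indices are nonzero
```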

The paper of KAN also says that:

Here we simply present our preliminary results on an extremely simple example, to demonstrate how one could possibly leverage locality in KANs (thanks to spline parametrizations) to reduce catastrophic forgetting. However, it remains unclear whether our method can generalize to more realistic setups, especially in high-dimensional cases where it is unclear how to define “locality”.

Based on our experiments, I think that KAN's continual learning holds in special domains, but not in general.

rafaelcp commented 1 month ago

The 2D Gaussian domain is also sparse, so it is not a matter of sparse vs. dense. It is a matter of dependency between variables (which, unfortunately, is the norm).

KindXiaoming commented 3 weeks ago

Hi, just want to draw your attention to this paper, which seems quite relevant: Distal Interference: Exploring the Limits of Model-Based Continual Learning