KindXiaoming / pykan

Kolmogorov Arnold Networks
MIT License

Pruning doesn't work for me #240

Closed: FractalySyn closed this issue 3 weeks ago

FractalySyn commented 2 months ago

I played with KANs in both versions 0.0.4 and 0.0.5 and couldn't manage to prune the KAN. The code

```python
model.prune()
model.plot(mask=True)
```

and

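```python
model2 = model.prune()  # prune() returns the pruned network
model2(dataset['train_input'])  # forward pass so the plot uses fresh activations
model2.plot()
```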

both displayed the same, unpruned network as

```python
model.plot()
```

I ran these snippets after training a KAN on a symbolic division data-generating process (DGP), following the tutorials. I made sure to run them in a dedicated venv with all the required package versions properly installed.
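For anyone trying to reproduce this, a division dataset can be laid out in the same dict shape the snippets above index into (`dataset['train_input']`). The sketch below is a NumPy stand-in, not the tutorial code (the tutorials use pykan's own dataset helper with torch tensors). The helper name `make_division_dataset`, the sampling ranges, and the sample counts are all assumptions; the denominator range is chosen only to keep it away from zero. A fixed seed makes reruns identical, which helps rule out seeding effects.

```python
import numpy as np

def make_division_dataset(n_train=1000, n_test=1000, seed=0):
    """Hypothetical helper: samples (x1, x2) and labels y = x1 / x2.

    Assumed ranges: x1 in [-1, 1], x2 in [0.5, 2] (keeps x2 away from zero).
    """
    rng = np.random.default_rng(seed)  # seeded so reruns give the same data

    def sample(n):
        x1 = rng.uniform(-1.0, 1.0, size=n)
        x2 = rng.uniform(0.5, 2.0, size=n)
        inputs = np.stack([x1, x2], axis=1)  # shape (n, 2)
        labels = (x1 / x2)[:, None]          # shape (n, 1)
        return inputs, labels

    train_input, train_label = sample(n_train)
    test_input, test_label = sample(n_test)
    return {
        "train_input": train_input, "train_label": train_label,
        "test_input": test_input, "test_label": test_label,
    }

dataset = make_division_dataset()
```

To feed these arrays to a PyTorch model, each one would still need converting, e.g. with `torch.from_numpy(arr).float()`.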

This issue does not arise in version 0.0.2.


I'll take this opportunity to ask an unrelated question: is the code used for the paper available? In particular, I would like to work on the simple symbolic tasks (division, etc.) and the Feynman dataset examples, but I haven't managed to reproduce them so far...

KindXiaoming commented 2 months ago

Hi, could you please share the plots as well?

The original code used to train the networks and make the plots in the paper is quite messy. The code I released is a significantly cleaned-up version, but I haven't rerun those experiments with the new code yet, so I can't guarantee you will get exactly the same results (e.g., initialization and grid extension have changed a bit). To be honest, rerunning the experiments with the new code is not at the top of my list, so I'd encourage you to run the baselines yourself. But if your results differ a lot from what's reported in the paper, either much better or much worse, please open another issue and I'm happy to discuss.

FractalySyn commented 2 months ago

Silly me, I didn't back up the code I used at the beginning, and now it seems to work as expected. I might have misinterpreted what I saw. Thinking about it now, it could have been related to the seeding issue. Anyway, if this happens again I'll let you know and upload the plots.


OK, thanks. I'll play some more with the released package. Best

sumitagarai commented 2 weeks ago

The pruning function doesn't work for me either. Could you please look into it?

[Screenshot: 2024-07-17 at 5:13:41 PM]
KindXiaoming commented 2 weeks ago

@sumitagarai You may need to tweak the threshold, e.g. `model = model.prune(node_th=5e-2)`, or prune it by hand with `model = model.prune(active_neurons_id=[[1]])`.
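Some context on why the threshold matters: the KAN paper scores each edge by the mean magnitude of its activation over the data, and keeps a hidden node only if both its strongest incoming edge and its strongest outgoing edge score above a threshold θ (1e-2 by default in the paper). If every node clears θ, there is nothing to remove and the pruned plot looks unchanged. The NumPy sketch below implements that criterion for a single hidden layer. It is not pykan's code; the function names, the array layout, and the toy scores are illustrative assumptions, and `node_th` presumably plays the role of θ.

```python
import numpy as np

def edge_scores(activations):
    """L1 score per edge: mean |phi(x)| over samples.

    activations: (n_samples, n_out, n_in) post-activation value of each edge.
    Returns an (n_out, n_in) array of scores.
    """
    return np.mean(np.abs(activations), axis=0)

def active_hidden_nodes(scores_in, scores_out, threshold=1e-2):
    """Indices of hidden nodes kept under the incoming/outgoing rule.

    scores_in:  (n_hidden, n_prev) scores of edges feeding each hidden node
    scores_out: (n_next, n_hidden) scores of edges leaving each hidden node
    """
    incoming = scores_in.max(axis=1)   # strongest incoming edge per node
    outgoing = scores_out.max(axis=0)  # strongest outgoing edge per node
    keep = (incoming > threshold) & (outgoing > threshold)
    return np.flatnonzero(keep).tolist()

# Toy [2, 5, 1] network with made-up edge scores (illustration only).
scores_in = np.array([[0.30, 0.20],
                      [0.03, 0.02],
                      [0.004, 0.001],
                      [0.50, 0.40],
                      [0.02, 0.01]])
scores_out = np.array([[0.60, 0.04, 0.50, 0.002, 0.03]])

print(active_hidden_nodes(scores_in, scores_out, threshold=1e-2))  # [0, 1, 4]
print(active_hidden_nodes(scores_in, scores_out, threshold=5e-2))  # [0]
```

Node 2 is dropped for a weak input and node 3 for a weak output even though its input is strong. Raising the threshold from 1e-2 to 5e-2 additionally drops nodes 1 and 4, which is the kind of effect a larger `node_th` would be expected to have when the default removes nothing.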