KindXiaoming / pykan

Kolmogorov Arnold Networks
MIT License

Train_loss and test_loss become NaN after using model.prune() #416

Open DonSteven opened 3 weeks ago

DonSteven commented 3 weeks ago

Hello, I'm working on a multi-class classification model and the results look good. But when I train the model after calling model.prune(), train_loss and test_loss easily become NaN, even if I set lr and step very small. Could anyone please help me solve this problem? Thanks a lot!

Here's my code. I include a scheduler: torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max, eta_min=eta_min)

[Five images attached to the original post, including the code and a plot of the model; not reproduced here.]

shrrr commented 3 weeks ago

I have the same problem. I also use KANs for five classification tasks, but after I complete the first training stage and run model = model.prune(), I get NaN in the train loss and test loss in the second training stage. However, after I set update_grid=False in the model.fit() function, the loss is no longer NaN. I don't know why!!

The model I defined is:

model = KAN(width=[9,10,2], grid=3, k=3, seed=42, device=device)

And training code is:

results = model.fit(data, opt="LBFGS", metrics=(train_acc, test_acc), loss_fn=torch.nn.CrossEntropyLoss(), steps=50)

Any helpful information or ideas, please let me know! Thank you!
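The workaround above can be wrapped in a small helper that also reports where NaNs first appear. This is only a minimal sketch, not part of pykan: `first_nan_step` and `refit_without_grid_update` are hypothetical names, and it assumes that `model.fit()` accepts `update_grid` (as reported in this comment) and returns a dict containing `train_loss` and `test_loss` sequences.

```python
import math


def first_nan_step(losses):
    """Return the index of the first NaN in a sequence of losses, or None."""
    for step, value in enumerate(losses):
        if math.isnan(float(value)):
            return step
    return None


def refit_without_grid_update(model, data, **fit_kwargs):
    """Run a second training stage with grid updates disabled.

    Assumption: model.fit() takes update_grid and returns a dict with
    'train_loss' and 'test_loss' entries (hypothetical wrapper, not pykan API).
    """
    fit_kwargs["update_grid"] = False  # force the workaround from this thread
    results = model.fit(data, **fit_kwargs)
    report = {
        key: first_nan_step(results.get(key, []))
        for key in ("train_loss", "test_loss")
    }
    return results, report
```

After pruning, this could be called as `results, report = refit_without_grid_update(model, data, opt="LBFGS", steps=50)`; a `report` entry of `None` means no NaN was seen in that loss history.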

DonSteven commented 3 weeks ago

Thanks for your reply! It really works!

andrewrgarcia commented 2 weeks ago

@DonSteven This might not be the best fix, but I had the same problem, so I have reverted to pykan==0.2.2 (where this issue does not occur) until it is resolved in the latest version.

srigas commented 2 weeks ago

Based on my very short experience working with the new version of pykan, I have noticed that the issue arises whenever disconnected nodes remain in the KAN's computation graph after pruning. The plot provided in your post appears to fall under this case. My workaround was to define a custom prune function, as follows:

def custom_prune(model, node_th=1e-2, edge_th=3e-2):
    # First prune edges
    model.prune_edge(edge_th, log_history=False)
    model.forward(model.cache_data)
    model.attribute()
    model.log_history('prune')
    # Then prune nodes
    model = model.prune_node(node_th, log_history=False)

    return model

# Example usage
model = custom_prune(model)

Essentially, all I've done is reverse the order of edge pruning and node pruning, so that edges are pruned before nodes rather than the other way around, as in the original prune() implementation. This appears to fix the issue of NaNs post-pruning, although I'm not sure whether it affects training in some other, unexpected way.
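To check whether a pruned network contains the kind of disconnected nodes mentioned above, one could scan the edge masks for hidden nodes that lack an active incoming or outgoing edge. The following is a toy sketch on plain nested lists. It does not read pykan's internal masks, and both the `masks[l][i][j]` layout and the function name `disconnected_hidden_nodes` are assumptions made for illustration.

```python
def disconnected_hidden_nodes(masks):
    """Find hidden nodes with no active incoming or no active outgoing edge.

    Assumed layout (illustrative only): masks[l][i][j] is truthy if the edge
    from node i in layer l to node j in layer l + 1 is active.
    Returns a list of (layer, node) pairs.
    """
    found = []
    # Hidden layers sit between two consecutive edge masks.
    for layer in range(1, len(masks)):
        incoming, outgoing = masks[layer - 1], masks[layer]
        for node in range(len(outgoing)):
            has_in = any(row[node] for row in incoming)
            has_out = any(outgoing[node])
            if not (has_in and has_out):
                found.append((layer, node))
    return found


# A 2-2-1 toy network where hidden node 1 has lost all incoming edges.
toy_masks = [
    [[1, 0], [1, 0]],  # input layer -> hidden layer
    [[1], [1]],        # hidden layer -> output layer
]
print(disconnected_hidden_nodes(toy_masks))  # [(1, 1)]
```

An empty result would suggest that every remaining hidden node is still connected on both sides.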

KindXiaoming commented 2 weeks ago

@srigas This is an interesting observation! It'd be awesome if people could confirm this for their cases!

Ilex00para commented 1 week ago

I encountered the same problem. It was fixed both by @srigas's function and by changing the order of node and edge pruning in the library function.

When using the standard .prune() function, it seems that the pruning changes are applied to the mask (zeros are added) but not to the attribution score calculations (the scores do not change with the standard pruning function, but they do with the custom function). Also, if .prune_edge() is not called, no pruning is done at all, even after calling .prune_node(). As I understand from the code, node pruning depends (only) on the edges, so are the functions called in the wrong order?

(Also, a note for the documentation: prune_edge does not return a new KAN; it just updates the mask.)