KindXiaoming / pykan

Kolmogorov Arnold Networks
MIT License

initialize_from_another_model ERROR #275

Closed. donupup closed this issue 3 months ago.

donupup commented 4 months ago

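While running experiments I ran into the following problem. Initializing a KAN with a larger grid from one with a smaller grid fails, even though I did not change the content or shape of the dataset. When I changed the input dimension to 10, initialization succeeded. Is this because the input dimension is too high? How should I fix this? Here is my code: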

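```python
from kan import KAN

# grids (e.g. a NumPy array of grid sizes), k, steps, device and kan_dataset
# are defined earlier in the experiment.
for i in range(grids.shape[0]):
    if i == 0:
        # Coarsest grid: build the model from scratch and run one forward pass.
        print("running model with grid " + str(grids[i]))
        model = KAN(width=[100, 10, 1], grid=grids[i], k=k, device=device)
        model(kan_dataset['train_input'])
    else:
        # Finer grid: initialize a new model from the previous one (widths must match).
        model = KAN(width=[100, 10, 1], grid=grids[i], k=k, device=device).initialize_from_another_model(model, kan_dataset['train_input'])
    results = model.train(kan_dataset, opt="LBFGS", steps=steps, lamb=1e-3, stop_grid_update_step=30, device=device)
```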

The error message is as follows: [screenshot]

kaneyxx commented 4 months ago

This error occurs when the model architecture has changed. I think you pruned or removed some edges or nodes to create a smaller network. If so, you need to initialize the new model with the same architecture as your latest (pruned) model.
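The explanation above can be illustrated without pykan: each edge of a KAN layer carries its own learnable spline, so a layer's parameters depend on both its input and output widths, and they can only be copied between models whose widths agree. A minimal pure-Python sketch follows; `edge_counts`, `check_can_initialize`, and the pruned width `[100, 4, 1]` are hypothetical illustrations, not part of pykan, and pykan's own checks may look different.

```python
def edge_counts(width):
    """Edges (each with its own learnable spline) in every layer of a KAN."""
    return [width[i] * width[i + 1] for i in range(len(width) - 1)]


def check_can_initialize(source_width, target_width):
    """Hypothetical pre-flight check: copying splines needs identical widths."""
    if list(source_width) != list(target_width):
        raise ValueError(
            f"source width {list(source_width)} has {edge_counts(source_width)} edges, "
            f"target width {list(target_width)} has {edge_counts(target_width)}; "
            "build the target with the source model's current width"
        )


original = [100, 10, 1]
pruned = [100, 4, 1]  # hypothetical width after pruning hidden nodes

try:
    # Fails: the larger-grid model still uses the pre-pruning width.
    check_can_initialize(pruned, original)
except ValueError as err:
    print("mismatch:", err)

# Succeeds: the new model is built with the pruned width.
check_can_initialize(pruned, pruned)
```

In pykan terms, the fix is to construct the finer-grid `KAN(...)` with the pruned model's width, rather than the original `[100, 10, 1]`, before calling `initialize_from_another_model`.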

KindXiaoming commented 3 months ago

In the newest update, grid refinement can be done more easily with `model = model.refine(new_grid)`. Tutorial: https://github.com/KindXiaoming/pykan/blob/master/tutorials/Example_1_function_fitting.ipynb. Please let me know if your problem persists.
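Grid refinement can be pictured on a single edge: the new, finer spline's coefficients are chosen so that it reproduces the old, coarser spline on sample inputs. Below is a minimal 1-D NumPy sketch of that idea, assuming uniform grids padded with `k` extra knots on each side and a least-squares fit on evenly spaced samples. The helper names are hypothetical; pykan samples with training inputs (as in the question's `initialize_from_another_model(model, kan_dataset['train_input'])`), and its internals may differ.

```python
import numpy as np


def extended_grid(a, b, num_intervals, k):
    """Uniform knots on [a, b] plus k extra knots on each side (G + 2k + 1 knots)."""
    h = (b - a) / num_intervals
    return a + h * np.arange(-k, num_intervals + k + 1)


def bspline_basis(x, knots, k):
    """Degree-k B-spline basis at points x via the Cox-de Boor recursion.

    Returns an array of shape (len(x), len(knots) - k - 1).
    """
    x = np.asarray(x, dtype=float)[:, None]
    t = np.asarray(knots, dtype=float)
    # Degree 0: indicator of each half-open knot interval [t_i, t_{i+1}).
    B = ((x >= t[:-1]) & (x < t[1:])).astype(float)
    for p in range(1, k + 1):
        # Blend neighbouring degree-(p-1) functions into degree-p functions.
        left = (x - t[:-p - 1]) / (t[p:-1] - t[:-p - 1]) * B[:, :-1]
        right = (t[p + 1:] - x) / (t[p + 1:] - t[1:-p]) * B[:, 1:]
        B = left + right
    return B


def refine_coefficients(coef, coarse_knots, fine_knots, k, x):
    """Fit fine-grid coefficients so the fine spline matches the coarse one at x."""
    target = bspline_basis(x, coarse_knots, k) @ coef
    design = bspline_basis(x, fine_knots, k)
    fine_coef, *_ = np.linalg.lstsq(design, target, rcond=None)
    return fine_coef


k, a, b = 3, -1.0, 1.0
coarse = extended_grid(a, b, 5, k)   # 5 intervals  -> 5 + 3 = 8 coefficients
fine = extended_grid(a, b, 10, k)    # 10 intervals -> 10 + 3 = 13 coefficients
coef = np.random.default_rng(0).normal(size=len(coarse) - k - 1)
x = np.linspace(a, b, 200)

fine_coef = refine_coefficients(coef, coarse, fine, k, x)
max_err = np.max(np.abs(bspline_basis(x, fine, k) @ fine_coef
                        - bspline_basis(x, coarse, k) @ coef))
print(fine_coef.shape, max_err)
```

Because every knot of the 5-interval grid inside [a, b] is also a knot of the 10-interval grid, the fit should reproduce the coarse spline up to floating-point error; refining to a grid whose knots are not nested would only approximate it. Note that this is done per edge, so every edge of the old model needs a counterpart in the new one.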

donupup commented 3 months ago

Thanks for your reply, the problem is solved.

donupup commented 3 months ago

The problem is solved. As @kaneyxx said, I was pruning the model later in my code, and that caused the error.