KindXiaoming / pykan

Kolmogorov Arnold Networks
MIT License

KAN 2.0: PLOT, PRUNE, MULT_ARITY - How to solve: #472

Open francescoporta opened 1 month ago


Good morning,

I found some issues and some possible improvements in the MultKAN library.

The MultKAN module does not seem to handle an integer mult_arity when multiplication nodes exist. I fixed it by pre-processing mult_arity into one list of arities per layer, as follows:

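if multiplication_bool:
    if isinstance(mult_arity, int):
        # One list per layer: the input layer has no multiplication nodes,
        # each hidden layer gets one arity per multiplication node,
        # and the output layer has no multiplication nodes either.
        mult_arity_list = [[]]

        for i in range(len(width) - 2):
            value = []
            # width[i + 1][1] is the number of multiplication nodes in hidden layer i + 1
            for j in range(width[i + 1][1]):
                value.append(mult_arity)
            mult_arity_list.append(value)

        mult_arity_list.append([])
        mult_arity = mult_arity_list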

This could be implemented directly in the KAN module.
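As a rough idea of what building this into the module could look like, here is a self-contained sketch of the same pre-processing as a function. The name expand_mult_arity is hypothetical (it is not a pykan function), and it assumes each width entry is either an int (sum nodes only) or [n_sum, n_mult], with the output layout used in the snippet above (empty lists for the input and output layers). Instead of an external multiplication_bool flag, it checks the widths for multiplication nodes itself.

```python
def expand_mult_arity(width, mult_arity):
    """Expand a scalar mult_arity into one list of arities per layer.

    Hypothetical helper (not part of pykan). Each entry of width is either
    an int (sum nodes only) or [n_sum, n_mult].
    """
    # Normalise every layer to [n_sum, n_mult]
    layers = []
    for w in width:
        if isinstance(w, int):
            layers.append([w, 0])
        else:
            w = list(w)
            layers.append(w + [0] * (2 - len(w)))

    # Nothing to do if the arity is already a list, or there are no mult nodes
    has_mult = any(n_mult > 0 for _, n_mult in layers)
    if not isinstance(mult_arity, int) or not has_mult:
        return mult_arity

    expanded = [[]]  # input layer: no multiplication nodes
    for _, n_mult in layers[1:-1]:
        # one arity per multiplication node in each hidden layer
        expanded.append([mult_arity] * n_mult)
    expanded.append([])  # output layer: no multiplication nodes
    return expanded
```

For example, expand_mult_arity([2, [3, 2], 1], 2) gives [[], [2, 2], []], while a network without multiplication nodes gets its integer arity back unchanged.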

prune() does not work properly with MultKAN. I fixed line 1950 as follows:

subnode_score = torch.cat([subnode_score, node_score[:, width[0]+i].unsqueeze(1).expand(out_dim, mult_arity[i])], dim=1)  # Added by Francesco Porta
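The corrected line copies the score of multiplication node i onto each of its mult_arity[i] input subnodes, which is what unsqueeze(1).expand(out_dim, mult_arity[i]) does for a single column. Below is a minimal pure-Python sketch of that bookkeeping, assuming node_score has one row per output dimension and its columns are the sum nodes followed by the multiplication nodes (so n_sum presumably plays the role of width[0] in the line above). The function name expand_node_scores is hypothetical and not part of pykan.

```python
def expand_node_scores(node_score, n_sum, mult_arity):
    """Copy each multiplication node's score onto its input subnodes.

    Hypothetical helper (not part of pykan). node_score is a list of rows,
    one per output dimension; each row holds n_sum sum-node scores followed
    by one score per multiplication node. mult_arity is an int (same arity
    for every multiplication node) or a list with one arity per node.
    """
    if not node_score:
        return []
    n_mult = len(node_score[0]) - n_sum
    if isinstance(mult_arity, int):
        arities = [mult_arity] * n_mult
    else:
        arities = list(mult_arity)
    if len(arities) != n_mult:
        raise ValueError("need one arity per multiplication node")

    expanded = []
    for row in node_score:
        new_row = list(row[:n_sum])  # sum-node scores pass through unchanged
        for i, arity in enumerate(arities):
            # same role as node_score[:, n_sum + i].unsqueeze(1).expand(out_dim, arity)
            new_row.extend([row[n_sum + i]] * arity)
        expanded.append(new_row)
    return expanded
```

Each output row ends up with n_sum + sum(arities) columns, one per subnode, which is the width the subnode scores must have for pruning to line up with the multiplication inputs.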

model.plot() does not automatically save the tree plot in a chosen format. I added a line that saves the figure to a chosen folder (line 1295):

plt.savefig(f"{folder}/{imageName}.svg", bbox_inches="tight", dpi=400)  # Added by Francesco Porta
plt.close()
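The line above hard-codes the .svg extension. Because matplotlib's savefig infers the output format from the file extension, making the format configurable only requires building the file name differently. The sketch below assumes a hypothetical helper, plot_output_path, that is not part of pykan; the set of accepted formats is an arbitrary subset chosen for illustration.

```python
import os

# A few formats matplotlib's savefig can write out of the box (chosen subset)
SUPPORTED_FORMATS = {"svg", "png", "pdf", "eps"}


def plot_output_path(folder, image_name, fmt="svg"):
    """Build '<folder>/<image_name>.<fmt>' and make sure the folder exists.

    Hypothetical helper (not part of pykan): plt.savefig infers the file
    format from the extension, so the format only needs to be in the name.
    """
    fmt = fmt.lower().lstrip(".")
    if fmt not in SUPPORTED_FORMATS:
        raise ValueError(f"unsupported format: {fmt!r}")
    os.makedirs(folder, exist_ok=True)  # create the target folder if missing
    return os.path.join(folder, f"{image_name}.{fmt}")
```

Inside plot(), the saving line could then read plt.savefig(plot_output_path(folder, imageName, "png"), bbox_inches="tight", dpi=400), followed by plt.close(). The dpi setting matters mainly for raster formats such as PNG.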

get_act() does not handle inputs that are on a different device from the model. I solved this by modifying line 2674:

self.forward(x.to(self.device))

I hope this helps anyone having the same issues.

francesco.porta99@gmail.com