I found some issues and some possible improvements in the MultKAN library.
The MultKAN module does not seem to handle an integer value of mult_arity when multiplication nodes exist. I fixed it by pre-processing mult_arity into a per-layer list as follows:
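    if multiplication_bool:
        # Expand a scalar mult_arity into one list of arities per layer
        if isinstance(mult_arity, int):
            mult_arity_list = [[]]  # empty entry for the input layer
            for i in range(len(width) - 2):
                value = []
                # one arity per multiplication node in layer i + 1
                for j in range(width[i + 1][1]):
                    value.append(mult_arity)
                mult_arity_list.append(value)
            mult_arity_list.append([])  # empty entry for the output layer
            mult_arity = mult_arity_list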
This could be implemented directly in the KAN module.
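For anyone who wants to try the pre-processing outside the library, here is a minimal standalone sketch of the same idea. The function name expand_mult_arity is hypothetical (it is not part of the library), and it assumes width is already a list of [n_sum, n_mult] pairs, one per layer.

```python
def expand_mult_arity(width, mult_arity):
    """Expand a scalar mult_arity into one list of arities per layer.

    Assumption: `width` is a list of [n_sum, n_mult] pairs, one per layer.
    """
    if not isinstance(mult_arity, int):
        return mult_arity  # already a per-layer list, leave it untouched
    # One arity entry per multiplication node in each hidden layer
    hidden = [[mult_arity] * width[i][1] for i in range(1, len(width) - 1)]
    # Empty entries for the input and output layers, as in the fix above
    return [[]] + hidden + [[]]


width = [[2, 0], [3, 2], [1, 1], [1, 0]]
print(expand_mult_arity(width, 2))  # [[], [2, 2], [2], []]
```

A list that is passed in is returned unchanged, so calling the function on an already expanded mult_arity does nothing.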
prune() does not work properly with MultKAN. I fixed line 1950 as follows:
subnode_score = torch.cat([subnode_score, node_score[:, width[0]+i].unsqueeze(1).expand(out_dim, mult_arity[i])], dim=1) # Added by Francesco Porta
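To show why the unsqueeze(1) matters, here is a small sketch with made-up shapes. The tensor values and sizes are illustrative only and are not taken from the library. A column selected with node_score[:, k] is 1-D, so expand() cannot stretch it to (out_dim, arity) unless a trailing dimension of size 1 is added first.

```python
import torch

out_dim, arity = 3, 2
# Illustrative scores with shape (out_dim, n_nodes)
node_score = torch.arange(12.0).reshape(out_dim, 4)

column = node_score[:, 1]  # shape (out_dim,)

# Without unsqueeze, expand aligns the existing dimension (size 3) with the
# last target dimension (size 2) and raises a RuntimeError.
try:
    column.expand(out_dim, arity)
    failed = False
except RuntimeError:
    failed = True

# With unsqueeze(1) the shape is (out_dim, 1), which expands to (out_dim, arity)
expanded = column.unsqueeze(1).expand(out_dim, arity)
print(failed, tuple(expanded.shape))
```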
model.plot() could automatically save the tree plot in a desired format.
I added a line of code (line 1295) that saves the plot to a desired folder:
plt.savefig(f"{folder}/{imageName}.svg", bbox_inches="tight", dpi=400) # Added by Francesco Porta
plt.close()
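The same saving step can be sketched as a small helper outside the library. The name save_current_figure and its parameters are hypothetical. The sketch assumes the figure to save is the current matplotlib figure, and it also creates the folder if it is missing, which the line above does not do.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt


def save_current_figure(folder, image_name, fmt="svg"):
    """Save the current figure as <folder>/<image_name>.<fmt> and close it."""
    os.makedirs(folder, exist_ok=True)  # create the target folder if needed
    path = os.path.join(folder, f"{image_name}.{fmt}")
    plt.savefig(path, bbox_inches="tight", dpi=400)
    plt.close()  # release the figure so repeated calls do not pile up
    return path


# Usage with a throwaway figure and a temporary folder
plt.plot([0, 1], [0, 1])
saved_path = save_current_figure(tempfile.mkdtemp(), "kan_tree")
print(os.path.exists(saved_path))
```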
get_act() does not handle different devices (for example, an input tensor that is not on the model's device). I solved this by modifying line 2674:
self.forward(x.to(self.device))
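As a rough illustration of the pattern, here is a toy module that moves the input to its own device before the forward pass. TinyModel is a hypothetical stand-in and not the library's class. Only the x.to(self.device) call mirrors the fix above.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Toy stand-in used only to illustrate the device handling."""

    def __init__(self, device="cpu"):
        super().__init__()
        self.device = torch.device(device)
        self.linear = nn.Linear(2, 1).to(self.device)

    def forward(self, x):
        return self.linear(x)

    def get_act(self, x):
        # Move the input to the model's device before the forward pass
        return self.forward(x.to(self.device))


model = TinyModel()
out = model.get_act(torch.zeros(4, 2))
print(tuple(out.shape), out.device)
```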
I hope this helps anyone running into the same issues.
francesco.porta99@gmail.com