KindXiaoming / pykan

Kolmogorov Arnold Networks
MIT License

AttributeError: 'float' object has no attribute 'to' #329

Closed: Sleevexiu closed this issue 2 months ago

Sleevexiu commented 2 months ago

When I run this code from hellokan.py:

model.prune()
model.plot(mask=True)

it fails with the following error:

AttributeError                            Traceback (most recent call last)
Cell In[18], line 1
----> 1 model.prune()
      2 model.plot(mask=True)

File e:\pykan-master\kan\MultKAN.py:956, in MultKAN.prune(self, node_th, edge_th)
    955 def prune(self, node_th=1e-2, edge_th=3e-2):
--> 956     self = self.prune_node(node_th, log_history=False)
    957     #self.prune_node(node_th, log_history=False)
    958     self.forward(self.cache_data)

File e:\pykan-master\kan\MultKAN.py:928, in MultKAN.prune_node(self, threshold, mode, active_neurons_id, log_history)
    924 model2.symbolic_fun[i].out_dim_mult = num_mult
    926 width_new.append([num_sum, num_mult])
--> 928 model2.act_fun[i] = model2.act_fun[i].get_subset(active_neurons_up[i], active_neurons_down[i])
    929 model2.symbolic_fun[i] = self.symbolic_fun[i].get_subset(active_neurons_up[i], active_neurons_down[i])
    931 model2.cache_data = self.cache_data

File e:\pykan-master\kan\KANLayer.py:305, in KANLayer.get_subset(self, in_id, out_id)
    283 def get_subset(self, in_id, out_id):
    284     '''
    285     get a smaller KANLayer from a larger KANLayer (used for pruning)
    286
(...)
...
    133 self.scale_base = torch.nn.Parameter(torch.ones(in_dim, out_dim, device=device) * scale_base * mask).requires_grad_(sb_trainable) # make scale trainable
    134 #else:
    135 #self.scale_base = torch.nn.Parameter(scale_base.to(device)).requires_grad_(sb_trainable)

AttributeError: 'float' object has no attribute 'to'
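The last line of the traceback says a plain Python float reached code that expected a tensor: `torch.Tensor` has a `.to(device)` method, but `float` does not. The minimal sketch below reproduces that failure mode without PyTorch and shows one defensive pattern. `FakeTensor`, `naive_move`, and `move_to_device` are hypothetical stand-ins for illustration, not pykan or PyTorch code.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: only models .to(device)."""

    def __init__(self, value, device="cpu"):
        self.value = value
        self.device = device

    def to(self, device):
        # Like torch.Tensor.to, return a copy placed on the requested device.
        return FakeTensor(self.value, device)


def naive_move(x, device):
    # Mirrors the failing pattern: assumes x is always tensor-like.
    return x.to(device)


def move_to_device(x, device):
    """Hypothetical helper: call .to() only on objects that support it."""
    if hasattr(x, "to"):
        return x.to(device)
    # A bare float (e.g. scale_base=1.0) has no .to(); wrap it instead of crashing.
    return FakeTensor(float(x), device)


try:
    naive_move(1.0, "cuda")
except AttributeError as err:
    print(err)  # 'float' object has no attribute 'to'

moved = move_to_device(1.0, "cuda")
print(type(moved).__name__, moved.value, moved.device)
```

With real PyTorch, `torch.as_tensor(x, device=device)` accepts both Python scalars and tensors, which is another way to avoid calling `.to()` on a float.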

KindXiaoming commented 2 months ago

Sorry, I haven't updated the tutorials yet, but the backend code has been updated. Please try:

model = model.prune()
model.plot()
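Why the assignment `model = model.prune()` matters: judging from the traceback, `prune` runs `self = self.prune_node(...)`, and rebinding `self` inside a method only changes a local name. The caller's `model` is unchanged unless the returned object is assigned back. The toy class below is a hypothetical illustration of that Python behavior, not pykan's implementation; `ToyModel` and its `scores` list are made up.

```python
class ToyModel:
    """Hypothetical toy model: a list of neuron 'importance' scores."""

    def __init__(self, scores):
        self.scores = list(scores)

    def prune_node(self, threshold):
        # Build and return a NEW, smaller model; self is not modified.
        return ToyModel([s for s in self.scores if s >= threshold])

    def prune(self, threshold=1e-2):
        # Rebinding `self` only changes this method's local variable,
        # so the pruned copy must be returned to be visible to the caller.
        self = self.prune_node(threshold)
        return self


model = ToyModel([0.5, 0.001, 0.2])
model.prune()              # result discarded: model still has 3 neurons
print(len(model.scores))   # 3
model = model.prune()      # reassign to keep the pruned copy
print(len(model.scores))   # 2
```

This only explains the calling convention; as the next comments note, the `AttributeError` itself still occurs.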

geliuxin commented 2 months ago

It doesn't work. I get the same error.

Sleevexiu commented 2 months ago

Yes, it still doesn't work. I also have another question: why does the plot look significantly different after I run this code?

model(dataset['train_input'])
model.plot(beta=100)
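One factor in how different the plots look is `beta`. Assuming the relation given in the pykan `plot` docstring, each activation is drawn with transparency tanh(beta * score), where the score is an activation magnitude computed during the forward pass (which is why `model(dataset['train_input'])` is run first). Under that assumption, a large value such as `beta=100` drives almost every edge to full opacity, while a small value leaves weak edges faint. The sketch below only evaluates that assumed formula; `edge_opacity` is a hypothetical helper and the scores are made-up numbers, not pykan output.

```python
import math


def edge_opacity(score, beta):
    """Assumed formula: opacity = tanh(beta * score) for a non-negative score."""
    return math.tanh(beta * score)


scores = [0.001, 0.01, 0.05, 0.3]  # illustrative activation scores, not real data
for beta in (3, 100):
    opacities = [round(edge_opacity(s, beta), 3) for s in scores]
    print(f"beta={beta}: {opacities}")
```

Comparing the two printed rows shows how much a larger `beta` compresses the visual difference between weak and strong edges.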

KindXiaoming commented 2 months ago

Link: https://github.com/KindXiaoming/pykan/issues/331