csinva / imodelsX

Scikit-learn friendly library to interpret and prompt-engineer text datasets using large language models.
https://csinva.io/imodelsX/
MIT License

double not float KAN #8

Open edmondja opened 4 months ago

edmondja commented 4 months ago

I get the following error when running your demo notebook:

RuntimeError                              Traceback (most recent call last)
Cell In[25], line 10
      7 X, y = make_classification(n_samples=5000, n_features=5, n_informative=3)
      8 model = imodelsx.KANClassifier(hidden_layer_size=64, device='cpu',
      9                                regularize_activation=1.0, regularize_entropy=1.0)
---> 10 model.fit(X, y)
     11 y_pred = model.predict(X)
     12 print('Test acc', accuracy_score(y, y_pred))

File ~/opt/anaconda3/envs/py311/lib/python3.11/site-packages/imodelsx/kan/kan_sklearn.py:80, in KAN.fit(self, X, y, batch_size, lr, weight_decay, gamma)
     78 x = x.view(-1, num_features).to(self.device)
     79 optimizer.zero_grad()
---> 80 output = self.model(x).squeeze()
     81 loss = criterion(output, labs.to(self.device).squeeze())
     82 if isinstance(self, (KANGAMClassifier, KANGAMRegressor)):

File ~/opt/anaconda3/envs/py311/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/opt/anaconda3/envs/py311/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/opt/anaconda3/envs/py311/lib/python3.11/site-packages/imodelsx/kan/kan_modules.py:296, in KANModule.forward(self, x, update_grid)
    294     if update_grid:
    295         layer.update_grid(x)
--> 296     x = layer(x)
    297 return x

File ~/opt/anaconda3/envs/py311/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/opt/anaconda3/envs/py311/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/opt/anaconda3/envs/py311/lib/python3.11/site-packages/imodelsx/kan/kan_modules.py:173, in KANLinearModule.forward(self, x)
    170 def forward(self, x: torch.Tensor):
    171     assert x.dim() == 2 and x.size(1) == self.in_features
--> 173 base_output = F.linear(self.base_activation(x), self.base_weight)
    174 spline_output = F.linear(
    175     self.b_splines(x).view(x.size(0), -1),
    176     self.scaled_spline_weight.view(self.out_features, -1),
    177 )
    178 return base_output + spline_output

RuntimeError: expected m1 and m2 to have the same dtype, but got: float != double
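For context, the likely trigger is that make_classification returns float64 (double) NumPy arrays while the KAN parameters are float32, so F.linear receives mismatched dtypes. A minimal user-side workaround, assuming the KANClassifier usage from the notebook, is to cast the inputs to float32 before fitting (the variable name X32 below is just for illustration):

```python
import numpy as np
import imodelsx
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score

X, y = make_classification(n_samples=5000, n_features=5, n_informative=3)

model = imodelsx.KANClassifier(hidden_layer_size=64, device='cpu',
                               regularize_activation=1.0, regularize_entropy=1.0)

# make_classification returns float64 arrays; casting to float32 matches the
# dtype of the model's weights and avoids the F.linear dtype mismatch.
X32 = X.astype(np.float32)
model.fit(X32, y)
y_pred = model.predict(X32)
print('Test acc', accuracy_score(y, y_pred))
```

A library-side fix would presumably cast X to the parameter dtype inside KAN.fit; the cast above only works around the issue from user code.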