KindXiaoming / pykan

Kolmogorov Arnold Networks
MIT License

HelloKAN error: train_loss: nan | test_loss: nan | reg: nan #379

Open lexmar07 opened 1 month ago

lexmar07 commented 1 month ago

Good afternoon!

HelloKAN does not work as intended:

1) With version 0.2.3: `TypeError: MultKAN.recover_save_act_in_fit() missing 1 required positional argument: 'old_save_act'`, as already reported in #375 and #372.

2) With version 0.2.2: `train_loss: nan | test_loss: nan | reg: nan`, resulting in:

ValueError                                Traceback (most recent call last)
[<ipython-input-14-7a27cf870c39>](https://localhost:8080/#) in <cell line: 3>()
      9     # automatic mode
     10     lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','abs']
---> 11     model.auto_symbolic(lib=lib)

10 frames
[/content/pykan/kan/MultKAN.py](https://localhost:8080/#) in auto_symbolic(self, a_range, b_range, lib, verbose)
   1377                         print(f'fixing ({l},{i},{j}) with 0')
   1378                     else:
-> 1379                         name, fun, r2, c = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib, verbose=False)
   1380                         self.fix_symbolic(l, i, j, name, verbose=verbose > 1, log_history=False)
   1381                         if verbose >= 1:

[/content/pykan/kan/MultKAN.py](https://localhost:8080/#) in suggest_symbolic(self, l, i, j, a_range, b_range, lib, topk, verbose, r2_loss_fun, c_loss_fun, weight_simple)
   1307         # getting r2 and complexities
   1308         for (name, content) in symbolic_lib.items():
-> 1309             r2 = self.fix_symbolic(l, i, j, name, a_range=a_range, b_range=b_range, verbose=False, log_history=False)
   1310             if r2 == -1e8: # zero function
   1311                 r2s.append(-1e8)

[/content/pykan/kan/MultKAN.py](https://localhost:8080/#) in fix_symbolic(self, l, i, j, fun_name, fit_params_bool, a_range, b_range, verbose, random, log_history)
    485             y = self.spline_postacts[l][:, j, i]
    486             #y = self.postacts[l][:, j, i]
--> 487             r2 = self.symbolic_fun[l].fix_symbolic(i, j, fun_name, x, y, a_range=a_range, b_range=b_range, verbose=verbose)
    488             if mask[i,j] == 0:
    489                 r2 = - 1e8

[/content/pykan/kan/Symbolic_KANLayer.py](https://localhost:8080/#) in fix_symbolic(self, i, j, fun_name, x, y, random, a_range, b_range, verbose)
    229             else:
    230                 #initialize from x & y and fun
--> 231                 params, r2 = fit_params(x,y,fun, a_range=a_range, b_range=b_range, verbose=verbose, device=self.device)
    232                 self.funs[j][i] = fun
    233                 self.funs_avoid_singularity[j][i] = fun_avoid_singularity

[/content/pykan/kan/utils.py](https://localhost:8080/#) in fit_params(x, y, fun, a_range, b_range, grid_number, iteration, verbose, device)
    259 
    260     post_fun = torch.nan_to_num(post_fun)
--> 261     reg = LinearRegression().fit(post_fun[:,None].detach().cpu().numpy(), y.detach().cpu().numpy())
    262     c_best = torch.from_numpy(reg.coef_)[0].to(device)
    263     d_best = torch.from_numpy(np.array(reg.intercept_)).to(device)

[/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_base.py](https://localhost:8080/#) in fit(self, X, y, sample_weight)
    682         accept_sparse = False if self.positive else ["csr", "csc", "coo"]
    683 
--> 684         X, y = self._validate_data(
    685             X, y, accept_sparse=accept_sparse, y_numeric=True, multi_output=True
    686         )

[/usr/local/lib/python3.10/dist-packages/sklearn/base.py](https://localhost:8080/#) in _validate_data(self, X, y, reset, validate_separately, **check_params)
    594                 y = check_array(y, input_name="y", **check_y_params)
    595             else:
--> 596                 X, y = check_X_y(X, y, **check_params)
    597             out = X, y
    598 

[/usr/local/lib/python3.10/dist-packages/sklearn/utils/validation.py](https://localhost:8080/#) in check_X_y(X, y, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, multi_output, ensure_min_samples, ensure_min_features, y_numeric, estimator)
   1088     )
   1089 
-> 1090     y = _check_y(y, multi_output=multi_output, y_numeric=y_numeric, estimator=estimator)
   1091 
   1092     check_consistent_length(X, y)

[/usr/local/lib/python3.10/dist-packages/sklearn/utils/validation.py](https://localhost:8080/#) in _check_y(y, multi_output, y_numeric, estimator)
   1098     """Isolated part of check_X_y dedicated to y validation"""
   1099     if multi_output:
-> 1100         y = check_array(
   1101             y,
   1102             accept_sparse="csr",

[/usr/local/lib/python3.10/dist-packages/sklearn/utils/validation.py](https://localhost:8080/#) in check_array(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator, input_name)
    897 
    898         if force_all_finite:
--> 899             _assert_all_finite(
    900                 array,
    901                 input_name=input_name,

[/usr/local/lib/python3.10/dist-packages/sklearn/utils/validation.py](https://localhost:8080/#) in _assert_all_finite(X, allow_nan, msg_dtype, estimator_name, input_name)
    144                     "#estimators-that-handle-nan-values"
    145                 )
--> 146             raise ValueError(msg_err)
    147 
    148     # for object dtype data, we only check for NaNs (GH-13254)

ValueError: Input y contains NaN.

What is strange is that 0.2.2 still worked just a few days ago, and I can hardly understand why this happens. Here is the Google Colab replication code: https://colab.research.google.com/drive/1YOU7AifdYieMWK2hDfKjlN7l6_n6BkvV?usp=sharing
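For reference, the failing cell follows the standard HelloKAN flow. A rough sketch of it (the width/grid/step values below are placeholders and may differ from the notebook):

```python
import torch
from kan import KAN, create_dataset

# HelloKAN toy target: f(x, y) = exp(sin(pi*x) + y^2)
f = lambda x: torch.exp(torch.sin(torch.pi * x[:, [0]]) + x[:, [1]] ** 2)
dataset = create_dataset(f, n_var=2)

# hyperparameters here are assumptions, not necessarily the notebook's
model = KAN(width=[2, 5, 1], grid=3, k=3)
model.fit(dataset, opt="LBFGS", steps=50, lamb=0.001)  # losses turn to nan here (problem 2)

# automatic symbolic regression -- this is where the ValueError above is raised
lib = ['x', 'x^2', 'x^3', 'x^4', 'exp', 'log', 'sqrt', 'tanh', 'sin', 'abs']
model.auto_symbolic(lib=lib)
```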

lexmar07 commented 1 month ago

OK, actually, one can add something like

model = model.prune(node_th = 1e-1)

to get rid of the hanging vertices in version 0.2.2.
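For anyone trying this, a minimal sketch of where that call sits, assuming the model was fit with save_act left at its default so activations are available for pruning:

```python
# after model.fit(...) and before symbolic regression
model = model.prune(node_th=1e-1)  # drop nodes below the threshold, removing hanging vertices

# the idea is that auto_symbolic then no longer tries to fit the dead (all-NaN) edges
lib = ['x', 'x^2', 'x^3', 'x^4', 'exp', 'log', 'sqrt', 'tanh', 'sin', 'abs']
model.auto_symbolic(lib=lib)
```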

Danuzco commented 1 month ago

The same issue here.

/usr/local/lib/python3.10/dist-packages/kan/MultKAN.py in fit(self, dataset, opt, steps, log, lamb, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff, update_grid, grid_update_num, loss_fn, lr, start_grid_update_step, stop_grid_update_step, batch, metrics, save_fig, in_vars, out_vars, beta, save_fig_freq, img_folder, singularity_avoiding, y_th, reg_metric, display_metrics)
    938             if _ == steps-1 and old_save_act:
    939                 #self.save_act = True
--> 940                 self.recover_save_act_in_fit()
    941 
    942             train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)

TypeError: MultKAN.recover_save_act_in_fit() missing 1 required positional argument: 'old_save_act'

It was working properly a couple of days ago.

carbonox-infernox commented 1 month ago

I'm having the same issue. The problem goes away when I pass `save_act=False` as a kwarg to `KAN` to override the default (`True`). I only just started playing with this stuff today, so I have no idea whether that's a dumb thing to do.
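For concreteness, a minimal sketch of that workaround (the target function and width/grid/k values are placeholders):

```python
import torch
from kan import KAN, create_dataset

f = lambda x: torch.exp(torch.sin(torch.pi * x[:, [0]]) + x[:, [1]] ** 2)
dataset = create_dataset(f, n_var=2)

# save_act=False avoids the recover_save_act_in_fit() TypeError in 0.2.3,
# at the cost of not caching activations (which, as far as I can tell,
# plot() and auto_symbolic() need).
model = KAN(width=[2, 5, 1], grid=3, k=3, save_act=False)
model.fit(dataset, opt="LBFGS", steps=50)
```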

EricCJoyce commented 1 month ago

Problem 1 goes away in version 0.2.3 by changing `MultKAN.py` line 940 to `self.recover_save_act_in_fit(True)`.
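In other words, the patched block would look roughly like this (the surrounding lines and the loop variable are reconstructed from the traceback above, so treat it as a sketch rather than the exact source):

```python
# kan/MultKAN.py, inside fit(), around line 940 in 0.2.3
if _ == steps - 1 and old_save_act:
    # self.save_act = True
    self.recover_save_act_in_fit(True)  # was: self.recover_save_act_in_fit()
```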

Problem 2 persists in version 0.2.3. Training seems to proceed, but inevitably all losses turn to nans.

chen-erqi commented 1 month ago

I'm seeing the same issue! MultKAN runs into NaNs very easily.

lucas15936 commented 1 month ago

Has anyone managed to fix problem 2?

EricCJoyce commented 1 month ago

> Has anyone managed to fix problem 2?

I've rolled back to the July 22 version of PyKAN, and, so far, training proceeds with no nans.

However, I still don't know what is causing problem 2 in later versions of PyKAN.

ZijuanXin commented 1 month ago

I'm having the same issue! Has anyone managed to fix problem 2?

lucas15936 commented 1 month ago

> Has anyone managed to fix problem 2?
>
> I've rolled back to the July 22 version of PyKAN, and, so far, training proceeds with no nans.
>
> However, I still don't know what is causing problem 2 in later versions of PyKAN.

How do you use the July 22 version? Did you just download the folder and run it like a normal package?

lucas15936 commented 1 month ago

Bug fixed in latest version (at least in HelloKAN).