KindXiaoming / pykan

Kolmogorov Arnold Networks
MIT License
13.56k stars, 1.19k forks

How to support complex data type? #87

Open yt7589 opened 1 month ago

yt7589 commented 1 month ago

Hi. It's a great project. I want to use KAN in the radar signal processing domain. As you know, radar signals are complex-valued. When I create a dataset with complex data and try to train KAN, it reports the error below:

Traceback (most recent call last):
  File "/home/yantao/software/miniforge3/envs/kan/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/yantao/software/miniforge3/envs/kan/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/yantao/awork/pykan/apps/fmcw/fmcw_app.py", line 39, in <module>
    main(args=args)
  File "/home/yantao/awork/pykan/apps/fmcw/fmcw_app.py", line 26, in main
    FmcwApp.startup(args=args)
  File "/home/yantao/awork/pykan/apps/fmcw/fmcw_app.py", line 23, in startup
    model.train(dataset, opt="LBFGS", steps=20, lamb=0.001, lamb_entropy=2.)
  File "/home/yantao/awork/pykan/kan/KAN.py", line 896, in train
    self.update_grid_from_samples(dataset['train_input'][train_id].to(device))
  File "/home/yantao/awork/pykan/kan/KAN.py", line 241, in update_grid_from_samples
    self.forward(x)
  File "/home/yantao/awork/pykan/kan/KAN.py", line 309, in forward
    x_numerical, preacts, postacts_numerical, postspline = self.act_fun[l](x)
  File "/home/yantao/software/miniforge3/envs/kan/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yantao/software/miniforge3/envs/kan/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yantao/awork/pykan/kan/KANLayer.py", line 173, in forward
    y = coef2curve(x_eval=x, grid=self.grid[self.weight_sharing], coef=self.coef[self.weight_sharing], k=self.k, device=self.device)  # shape (size, batch)
  File "/home/yantao/awork/pykan/kan/spline.py", line 100, in coef2curve
    y_eval = torch.einsum('ij,ijk->ik', coef, B_batch(x_eval, grid, k, device=device))
  File "/home/yantao/awork/pykan/kan/spline.py", line 59, in B_batch
    B_km1 = B_batch(x[:, 0], grid=grid[:, :, 0], k=k - 1, extend=False, device=device)
  File "/home/yantao/awork/pykan/kan/spline.py", line 59, in B_batch
    B_km1 = B_batch(x[:, 0], grid=grid[:, :, 0], k=k - 1, extend=False, device=device)
  File "/home/yantao/awork/pykan/kan/spline.py", line 59, in B_batch
    B_km1 = B_batch(x[:, 0], grid=grid[:, :, 0], k=k - 1, extend=False, device=device)
  File "/home/yantao/awork/pykan/kan/spline.py", line 57, in B_batch
    value = (x >= grid[:, :-1]) * (x < grid[:, 1:])
RuntimeError: "ge_cpu" not implemented for 'ComplexDouble'

How can I solve this problem? I set up the development environment with pip install -r requirements. I noticed that the installed PyTorch is the CPU version. Would switching to the GPU version of PyTorch solve this problem?

KindXiaoming commented 1 month ago

Hi, unfortunately pykan doesn't support complex numbers. However, you may try treating real and imaginary parts separately, so you end up feeding KAN with vectors of doubled length.
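
A minimal sketch of the split suggested above, assuming the complex samples are stored as a PyTorch tensor of shape (batch, n). The helper names complex_to_real_features and real_features_to_complex are hypothetical and not part of pykan; the random tensor stands in for real radar data.

```python
import torch


def complex_to_real_features(z: torch.Tensor) -> torch.Tensor:
    """Map a (batch, n) complex tensor to a (batch, 2n) real tensor.

    The first n columns hold the real parts, the last n the imaginary parts.
    """
    return torch.cat([z.real, z.imag], dim=-1)


def real_features_to_complex(x: torch.Tensor) -> torch.Tensor:
    """Inverse of complex_to_real_features: (batch, 2n) real -> (batch, n) complex."""
    n = x.shape[-1] // 2
    return torch.complex(x[..., :n], x[..., n:])


# Hypothetical stand-in for radar samples: 8 examples, 4 complex features each.
z = torch.randn(8, 4, dtype=torch.complex128)

x = complex_to_real_features(z)       # real-valued, doubled length
z_back = real_features_to_complex(x)  # recover the complex view if needed

print(x.shape, x.dtype)
```

The real-valued tensor x can then be used wherever the complex tensor was used before (for example as dataset['train_input']), with the input width of the model doubled accordingly. Complex-valued labels can be split the same way.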

pykan doesn't support GPU out of the box, so I'd suggest debugging (with small-scale datasets) on CPUs first.
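
A small CPU-only check along those lines, as a sketch rather than pykan's actual code: it mimics the comparison from the traceback (x >= grid[:, :-1]) on a tiny tensor. The grid and input values are made up for illustration. The exception comes from applying an ordering comparison to a complex dtype, which suggests the dtype rather than the CPU build of PyTorch is the cause; this sketch does not test a GPU.

```python
import torch

# Made-up stand-ins for one input sample and a small spline grid.
x = torch.tensor([[0.5 + 0.1j]], dtype=torch.complex128)
grid = torch.linspace(-1.0, 1.0, 6, dtype=torch.float64).unsqueeze(0)

# Ordering comparisons are not defined for complex tensors,
# so this mirrors the failure in the traceback.
try:
    _ = x >= grid[:, :-1]
    complex_comparison_ok = True
except RuntimeError as err:
    complex_comparison_ok = False
    print("complex input failed:", err)

# The same interval test on the real part alone works as expected:
# exactly one grid interval contains the value 0.5.
mask = (x.real >= grid[:, :-1]) * (x.real < grid[:, 1:])
print(mask)
```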

KindXiaoming commented 1 month ago

Is there any particular reason in your domain that you don't want to separate the real and imaginary parts? I know that in some applications people want to constrain the model to holomorphic functions, where complex neural networks are favored. In other cases, I don't see a strong reason not to treat the real and imaginary parts separately and feed them into real-valued neural networks (which are much more optimized than complex-valued ones).
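
For context on the holomorphic case mentioned above, here is the standard condition, stated as background rather than as anything specific to pykan. Writing z = x + iy and f(z) = u(x, y) + i v(x, y), the function f is holomorphic when u and v are differentiable and satisfy the Cauchy-Riemann equations:

```latex
\frac{\partial u}{\partial x} = \frac{\partial v}{\partial y},
\qquad
\frac{\partial u}{\partial y} = -\frac{\partial v}{\partial x}
```

A real-valued network fed (x, y) learns u and v as independent functions, so nothing enforces this coupling. That only matters if the application actually requires the model to be holomorphic.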