Blealtan / efficient-kan

An efficient pure-PyTorch implementation of Kolmogorov-Arnold Network (KAN).
MIT License

Train Multi-KANs simultaneously! Coding realization #30

Closed dwtum closed 1 month ago

dwtum commented 1 month ago

Hey, guys. I have the following problem. The model structure is:

y_pred = KAN_1(X1) * f1(X1) + KAN_2(X2) * f2(X2) + KAN_3(X3) * f3(X3), where f1, f2, f3 are fixed, known functions.

How can I train KAN_i, i = 1, 2, 3 simultaneously? I am wondering how to realize this training process.

UnbearableFate commented 1 month ago

Maybe you can use torch.distributed to train your different models in separate processes and all-reduce the results.

Blealtan commented 1 month ago

This has nothing to do with distributed training. Simply compute your loss on y_pred and call loss.backward(); optimizer.step(). Autograd handles the three branches automatically, as long as one optimizer holds the parameters of all three models.
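A minimal sketch of this idea (not code from the repo): `torch.nn.Linear` stands in for each `KAN_i` to keep the example self-contained; swap in `KAN([...])` from efficient-kan for the real thing. The choices of f1, f2, f3 (sin, cos, identity) and all shapes are illustrative assumptions.

```python
import itertools
import torch

torch.manual_seed(0)

# Stand-ins for KAN_1, KAN_2, KAN_3; replace with efficient_kan.KAN([...]).
kan1 = torch.nn.Linear(1, 1)
kan2 = torch.nn.Linear(1, 1)
kan3 = torch.nn.Linear(1, 1)

# Fixed, known functions (arbitrary examples).
f1, f2, f3 = torch.sin, torch.cos, lambda x: x

# One optimizer over all three models' parameters trains them simultaneously.
params = itertools.chain(kan1.parameters(), kan2.parameters(), kan3.parameters())
optimizer = torch.optim.Adam(params, lr=1e-2)

# Toy data: three inputs and one target.
X1, X2, X3 = torch.randn(64, 1), torch.randn(64, 1), torch.randn(64, 1)
y = torch.randn(64, 1)

def predict():
    # y_pred = KAN_1(X1)*f1(X1) + KAN_2(X2)*f2(X2) + KAN_3(X3)*f3(X3)
    return kan1(X1) * f1(X1) + kan2(X2) * f2(X2) + kan3(X3) * f3(X3)

loss_fn = torch.nn.MSELoss()
initial_loss = loss_fn(predict(), y).item()

for _ in range(200):
    optimizer.zero_grad()
    loss = loss_fn(predict(), y)
    loss.backward()   # gradients flow through all three branches at once
    optimizer.step()  # one step updates all three models

final_loss = loss_fn(predict(), y).item()
```

The key point is that `itertools.chain` hands every model's parameters to a single optimizer, so one `backward()`/`step()` pair updates all three KANs jointly.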