KindXiaoming / pykan

Kolmogorov Arnold Networks
MIT License

Copying KAN models #285

Open Arjeus opened 2 months ago

Arjeus commented 2 months ago

How can I copy a KAN model? I am trying to copy a pruned model but am running into errors. I cannot simply rebuild the model and load the weights into it with `state_dict`, because its architecture was changed automatically by pruning. Thanks.

Here is an example of the error:

```
RuntimeError                              Traceback (most recent call last)
Cell In[8], line 1
----> 1 model_bk = deepcopy(model)

File /usr/lib/python3.10/copy.py:172, in deepcopy(x, memo, _nil)
    170         y = x
    171     else:
--> 172         y = _reconstruct(x, memo, *rv)
    174 # If is its own copy, don't memoize.
    175 if y is not x:

File /usr/lib/python3.10/copy.py:271, in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
    269 if state is not None:
    270     if deep:
--> 271         state = deepcopy(state, memo)
    272     if hasattr(y, '__setstate__'):
    273         y.__setstate__(state)

File /usr/lib/python3.10/copy.py:146, in deepcopy(x, memo, _nil)
    144 copier = _deepcopy_dispatch.get(cls)
    145 if copier is not None:
--> 146     y = copier(x, memo)
    147 else:
    148     if issubclass(cls, type):

File /usr/lib/python3.10/copy.py:231, in _deepcopy_dict(x, memo, deepcopy)
    229 memo[id(x)] = y
    230 for key, value in x.items():
--> 231     y[deepcopy(key, memo)] = deepcopy(value, memo)
    232 return y

File /usr/lib/python3.10/copy.py:146, in deepcopy(x, memo, _nil)
    144 copier = _deepcopy_dispatch.get(cls)
    145 if copier is not None:
--> 146     y = copier(x, memo)
    147 else:
    148     if issubclass(cls, type):

File /usr/lib/python3.10/copy.py:206, in _deepcopy_list(x, memo, deepcopy)
    204 append = y.append
    205 for a in x:
--> 206     append(deepcopy(a, memo))
    207 return y

File /usr/lib/python3.10/copy.py:153, in deepcopy(x, memo, _nil)
    151 copier = getattr(x, "__deepcopy__", None)
    152 if copier is not None:
--> 153     y = copier(memo)
    154 else:
    155     reductor = dispatch_table.get(cls)

File ~/code/pykan-env/lib/python3.10/site-packages/torch/_tensor.py:86, in Tensor.__deepcopy__(self, memo)
     84     return handle_torch_function(Tensor.__deepcopy__, (self,), self, memo)
     85 if not self.is_leaf:
---> 86     raise RuntimeError(
     87         "Only Tensors created explicitly by the user "
     88         "(graph leaves) support the deepcopy protocol at the moment. "
     89         "If you were attempting to deepcopy a module, this may be because "
     90         "of a torch.nn.utils.weight_norm usage, "
     91         "see https://github.com/pytorch/pytorch/pull/103001"
     92     )
     93 if id(self) in memo:
     94     return memo[id(self)]

RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment. If you were attempting to deepcopy a module, this may be because of a torch.nn.utils.weight_norm usage, see https://github.com/pytorch/pytorch/pull/103001
```
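The traceback shows `deepcopy` walking from the model's `__dict__` into a list and failing on a tensor that is not a graph leaf, i.e. a tensor computed from parameters that require gradients and still carrying its autograd history. The sketch below reproduces that failure with plain PyTorch. The `Toy` module and the `detach_cached_tensors` helper are hypothetical, and the sketch assumes (without checking pykan's source) that the offending tensors are cached intermediate results that are safe to detach.

```python
import copy

import torch
import torch.nn as nn


class Toy(nn.Module):
    """Hypothetical module that caches an intermediate result in a list attribute."""

    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.ones(3))
        # self.w * 2 has a grad_fn, so it is a non-leaf tensor; the list lives in
        # the module's __dict__, matching the dict -> list -> tensor path above.
        self.cache = [self.w * 2]


def detach_cached_tensors(module: nn.Module) -> None:
    """Hypothetical helper: on every submodule, replace non-leaf tensors held in
    plain attributes (directly or inside a list) with detached copies."""

    def fix(value):
        if isinstance(value, torch.Tensor) and not value.is_leaf:
            return value.detach()  # same data, no autograd history, now a leaf
        return value

    for sub in module.modules():
        # Parameters, buffers and submodules are not plain __dict__ entries,
        # so they are left untouched here.
        for name, value in list(vars(sub).items()):
            if isinstance(value, torch.Tensor):
                setattr(sub, name, fix(value))
            elif isinstance(value, list):
                setattr(sub, name, [fix(v) for v in value])


model = Toy()
try:
    copy.deepcopy(model)
except RuntimeError as err:
    print("deepcopy failed:", err)

detach_cached_tensors(model)
model_bk = copy.deepcopy(model)  # succeeds once the cached tensors are detached
```

Detaching only drops the cached tensors' gradient history; if a real model needs that history later (for example for a backward pass through a cached value), this workaround would not be appropriate.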

KindXiaoming commented 1 month ago

Yeah, `copy.deepcopy` doesn't work. After updating to the most recent version, you can use:

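```python
from kan.ckpt import saveckpt, loadckpt

path = 'model'
saveckpt(model, path)          # save a checkpoint of `model` under `path`
model_copied = loadckpt(path)  # rebuild a separate model from that checkpoint
```
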
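The checkpoint route avoids the `deepcopy` problem because it never copies live tensors: judging from the `saveckpt` frame in the traceback further down, it gathers constructor settings such as `sp_trainable` and `device` from the model, and a fresh instance is then built from them. Below is a minimal in-memory sketch of that general config-plus-`state_dict` pattern in plain PyTorch. `MLP`, `save_copy` and `load_copy` are hypothetical names, and this is not pykan's actual `saveckpt`/`loadckpt` code. Because `width` is read from the model at save time, a model whose layers have been shrunk (for example by pruning) is rebuilt with its current shape rather than its original one.

```python
import io

import torch
import torch.nn as nn


class MLP(nn.Module):
    """Hypothetical stand-in whose architecture is fully described by `width`."""

    def __init__(self, width):
        super().__init__()
        self.width = list(width)
        self.layers = nn.ModuleList(
            nn.Linear(n_in, n_out) for n_in, n_out in zip(width[:-1], width[1:])
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def save_copy(model: MLP, buf: io.BytesIO) -> None:
    # Store the constructor arguments next to the weights.
    torch.save({"config": {"width": model.width}, "state": model.state_dict()}, buf)


def load_copy(buf: io.BytesIO) -> MLP:
    buf.seek(0)  # read from the start of the in-memory file
    ckpt = torch.load(buf, weights_only=True)
    model = MLP(**ckpt["config"])  # fresh instance with the saved shape
    model.load_state_dict(ckpt["state"])  # then copy the weights in
    return model


model = MLP([2, 3, 1])
buf = io.BytesIO()
save_copy(model, buf)
model_copied = load_copy(buf)
```

When no file is needed, the same idea works without serialization: build `MLP(model.width)` and call `load_state_dict(model.state_dict())` on it, since `state_dict()` returns detached tensors by default.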
lexmar07 commented 2 weeks ago

I get the following error:

```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-169-6ba17f8ff70b> in <cell line: 3>()
      1 from kan.ckpt import *
      2 path ='/content/model'
----> 3 saveckpt(model, path)
      4 model_copied = loadckpt(path)
      5 

/content/pykan/kan/ckpt.py in saveckpt(model, path)
     18         sp_trainable = model.sp_trainable,
     19         sb_trainable = model.sb_trainable,
---> 20         device = model.device,
     21         state_id = model.state_id,
     22         auto_save = model.auto_save,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in __getattr__(self, name)
   1686             if name in modules:
   1687                 return modules[name]
-> 1688         raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
   1689 
   1690     def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:

AttributeError: 'MultKAN' object has no attribute 'device'
```
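The `AttributeError` is raised by `torch.nn.Module.__getattr__`, which is only consulted after normal attribute lookup fails and raises for any name that is not a registered parameter, buffer or submodule: `saveckpt` reads `model.device`, and this `MultKAN` instance has no such attribute. One possible stopgap, not verified against pykan, is to give the model a plain `device` attribute taken from its parameters before saving. `ensure_device_attr` is a hypothetical helper, and the sketch assumes `saveckpt` accepts a `torch.device` value there. If `saveckpt` then fails on another missing attribute, that may indicate that `kan/ckpt.py` and the `MultKAN` class in use come from different pykan versions.

```python
import torch
import torch.nn as nn


def ensure_device_attr(model: nn.Module) -> torch.device:
    """Hypothetical helper: make sure `model.device` exists before saving.

    If the attribute is missing, use the device of the first parameter
    (or CPU for a parameter-free module) and store it as a plain attribute.
    """
    try:
        return model.device  # nn.Module.__getattr__ raises AttributeError if unset
    except AttributeError:
        first = next(model.parameters(), None)
        device = first.device if first is not None else torch.device("cpu")
        model.device = device  # stored in model.__dict__, not in the state_dict
        return device


layer = nn.Linear(2, 2)
print(ensure_device_attr(layer))  # the device the layer's weights live on
```

In the thread's setting this would be called as `ensure_device_attr(model)` right before `saveckpt(model, path)`.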