KindXiaoming / pykan

Kolmogorov Arnold Networks
MIT License
15.12k stars 1.4k forks source link

An error is reported when the model is loaded #452

Open Saber-xxf opened 2 months ago

Saber-xxf commented 2 months ago

image

ConstructorError                          Traceback (most recent call last)
Cell In[33], line 1
----> 1 model = model.rewind('0.11')
      3 model.plot(scale=1)

File D:\Anconda\A\envs\pytorch\lib\site-packages\kan\MultKAN.py:647, in MultKAN.rewind(self, model_id)
    643 self.saveckpt(path=self.ckpt_path+'/'+f'{self.round}.{self.state_id}')
    645 print('rewind to model version '+f'{self.round-1}.{self.state_id}'+', renamed as '+f'{self.round}.{self.state_id}')
--> 647 return MultKAN.loadckpt(path=self.ckpt_path+'/'+str(model_id))

File D:\Anconda\A\envs\pytorch\lib\site-packages\kan\MultKAN.py:554, in MultKAN.loadckpt(path)
    534 '''
    535 load checkpoint from path
    536
   (...)
    551 >>> KAN.loadckpt('./mark')
    552 '''
    553 with open(f'{path}_config.yml', 'r') as stream:
--> 554     config = yaml.safe_load(stream)
    556 state = torch.load(f'{path}_state')
    558 model_load = MultKAN(width=config['width'],
    559              grid=config['grid'],
    560              k=config['k'],
   (...)
    573              round = config['round']+1,
    574              device = config['device'])

File D:\Anconda\A\envs\pytorch\lib\site-packages\yaml\__init__.py:125, in safe_load(stream)
    117 def safe_load(stream):
    118     """
    119     Parse the first YAML document in a stream
    120     and produce the corresponding Python object.
   (...)
    123     to be safe for untrusted input.
    124     """
--> 125     return load(stream, SafeLoader)

File D:\Anconda\A\envs\pytorch\lib\site-packages\yaml\__init__.py:81, in load(stream, Loader)
     79 loader = Loader(stream)
     80 try:
---> 81     return loader.get_single_data()
     82 finally:
     83     loader.dispose()

File D:\Anconda\A\envs\pytorch\lib\site-packages\yaml\constructor.py:51, in BaseConstructor.get_single_data(self) 49 node = self.get_single_node() 50 if node is not None: ---> 51 return self.construct_document(node) 52 return None

File D:\Anconda\A\envs\pytorch\lib\site-packages\yaml\constructor.py:60, in BaseConstructor.construct_document(self, node) 58 self.state_generators = [] 59 for generator in state_generators: ---> 60 for dummy in generator: 61 pass 62 self.constructed_objects = {}

File D:\Anconda\A\envs\pytorch\lib\site-packages\yaml\constructor.py:413, in SafeConstructor.construct_yaml_map(self, node) 411 data = {} 412 yield data --> 413 value = self.construct_mapping(node) 414 data.update(value)

File D:\Anconda\A\envs\pytorch\lib\site-packages\yaml\constructor.py:218, in SafeConstructor.construct_mapping(self, node, deep) 216 if isinstance(node, MappingNode): 217 self.flatten_mapping(node) --> 218 return super().construct_mapping(node, deep=deep)

File D:\Anconda\A\envs\pytorch\lib\site-packages\yaml\constructor.py:143, in BaseConstructor.construct_mapping(self, node, deep) 140 if not isinstance(key, collections.abc.Hashable): 141 raise ConstructorError("while constructing a mapping", node.start_mark, 142 "found unhashable key", key_node.start_mark) --> 143 value = self.construct_object(value_node, deep=deep) 144 mapping[key] = value 145 return mapping

File D:\Anconda\A\envs\pytorch\lib\site-packages\yaml\constructor.py:100, in BaseConstructor.construct_object(self, node, deep)
     98         constructor = self.__class__.construct_mapping
     99 if tag_suffix is None:
--> 100     data = constructor(self, node)
    101 else:
    102     data = constructor(self, tag_suffix, node)

File D:\Anconda\A\envs\pytorch\lib\site-packages\yaml\constructor.py:427, in SafeConstructor.construct_undefined(self, node) 426 def construct_undefined(self, node): --> 427 raise ConstructorError(None, None, 428 "could not determine a constructor for the tag %r" % node.tag, 429 node.start_mark)

ConstructorError: could not determine a constructor for the tag 'tag:yaml.org,2002:python/object/apply:numpy.core.multiarray.scalar' in "./model/0.11_config.yml", line 6, column 7

NRodion commented 2 months ago

I have the same error while trying to load a model.

NRodion commented 2 months ago

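The workaround is to change `config = yaml.safe_load(stream)` in `MultKAN.loadckpt` to `config = yaml.unsafe_load(stream)`. This works because the saved `_config.yml` contains a NumPy scalar written with a Python-specific YAML tag (`python/object/apply:numpy.core.multiarray.scalar`), which `safe_load` refuses to construct. Note that `unsafe_load` can execute arbitrary code while parsing, so use it only on checkpoint files you created yourself.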