remigenet / TKAN

TKAN: Temporal Kolmogorov-Arnold Networks

The problem has not been completely solved. #15

Open ChineseDictionary opened 3 weeks ago

ChineseDictionary commented 3 weeks ago
The problem has not been completely solved. When I try to load the model from a different script in the same environment, the problem still occurs. How can we fix this? Thank you.

The save and load steps are the same as above. I have updated TKAN and keras_efficient_kan to the latest version. The error message is:

ValueError: A total of 2 objects could not be loaded. Example error message for object : The shape of the target variable and the shape of the target value in variable.assign(value) must match. variable.shape=(1, 10), Received: value.shape=(1, 10). Target variable: <KerasVariable shape=(1, 65, 10), dtype=float32, path=tkan/tkan_cell/kan_linear/grid> List of objects that could not be loaded: [, ]

Originally posted by @ChineseDictionary in https://github.com/remigenet/TKAN/issues/12#issuecomment-2363149916

ducanbk13 commented 6 days ago

Hi, I have the same problem. The error message is something like: The shape of the target variable and the shape of the target value in variable.assign(value) must match. variable.shape=(1, 10), Received: value.shape=(1, 10). Target variable: <KerasVariable shape=(1, 2, 10), dtype=float32, path=tkan_6/tkan_cell_6/kan_linear_6/grid>

List of objects that could not be loaded: [, , , ]

I have tested many approaches, from loading an .h5 file to using the default Keras format, and none of them work. Has anyone fixed this?

remigenet commented 8 hours ago

Hi! I'm only just back from holidays, sorry for the delay! I found the origin of the issue: it was linked to the GridInitialize in KANLinear. It was tricky because the problem didn't occur at all in JAX (that's why I thought it was working), only in TensorFlow and Torch.

The fix was simply to rely on NumPy in the initializer instead of the Keras backend. It seems that either tile or another operation causes the saved weights not to have the same shape. (My guess is that tile in TensorFlow and Torch creates a kind of view, while JAX really creates the matrix of values, but I'm still unsure; it's a pure guess.)

I have now added tests for this in all backends and pushed a new version to PyPI, so please update to 0.4.3 and let me know if it works!
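To illustrate the kind of change described in the fix, here is a minimal sketch of a NumPy-only grid initializer. It is not the actual TKAN or keras_efficient_kan code. The function name `numpy_grid_initializer`, the `grid_range` default, and the assumed layout `(1, n_points, in_features)` (inferred from the `(1, 65, 10)` grid shape in the error message) are all hypothetical. What it shows is that when the array is built entirely in NumPy, its shape is exactly the requested shape, which is what `variable.assign(value)` checks when weights are reloaded.

```python
import numpy as np


def numpy_grid_initializer(shape, grid_range=(-1.0, 1.0), dtype="float32"):
    """Hypothetical NumPy-only KAN grid initializer (not the library's code).

    Assumes shape == (1, n_points, in_features): n_points evenly spaced
    knots over grid_range, repeated identically for every input feature.
    """
    batch, n_points, in_features = shape
    if batch != 1:
        raise ValueError(f"expected a leading dimension of 1, got {batch}")
    lo, hi = grid_range
    # 1-D knot vector of length n_points.
    knots = np.linspace(lo, hi, n_points, dtype=dtype)
    # Reshape to (1, n_points, 1), then tile along the feature axis.
    # np.tile returns a new, fully materialized array of shape (1, n_points, in_features).
    return np.tile(knots.reshape(1, n_points, 1), (1, 1, in_features))


# Usage: the returned array matches the requested variable shape exactly.
grid = numpy_grid_initializer((1, 65, 10))
print(grid.shape, grid.dtype)  # (1, 65, 10) float32
```

In Keras 3, a function like this could be wrapped in a subclass of keras.initializers.Initializer whose __call__(self, shape, dtype=None) returns the NumPy array. Whether TKAN 0.4.3 is structured exactly this way is an assumption, not something stated in the thread.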