remigenet / TKAT

Temporal Kolmogorov-Arnold Transformer

ValueError: Layer functional weight shape (1, 16, 10) is not compatible with provided weight shape (1, 10). #9

Open · mw66 opened 2 months ago

mw66 commented 2 months ago

Encountered a strange error in the middle of training:

Epoch 54/1000
98/98 ━━━━━━━━━━━━━━━━━━━━ 421s 4s/step - loss: 45.3344 - mae: 3.3803 - val_loss: 35.3193 - val_mae: 3.5939
Epoch 55/1000
98/98 ━━━━━━━━━━━━━━━━━━━━ 413s 4s/step - loss: 40.8876 - mae: 3.3554 - val_loss: 37.2913 - val_mae: 3.7625
Traceback (most recent call last):
...
/lib/python3.11/site-packages/keras/src/layers/layer.py", line 703, in set_weights
    raise ValueError(
ValueError: Layer functional weight shape (1, 16, 10) is not compatible with provided weight shape (1, 10).
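For context on what this message means: Keras's `Layer.set_weights` first checks that it received one array per entry in `layer.weights`, then matches arrays to variables purely by position and raises this `ValueError` on the first shape mismatch. The name "functional" is the model's own name (see the summary below), so the call was made on the whole model. The pure-Python sketch below mirrors those two checks with a hypothetical `check_weight_shapes` helper (not part of Keras) so the message can be reproduced without building a model. The two shapes come from the traceback; the swapped order is only an assumption showing how a list of the right length can still mismatch, not a diagnosis of this issue.

```python
def check_weight_shapes(layer_name, expected_shapes, provided_shapes):
    """Hypothetical helper (not part of Keras) mirroring the checks that
    keras Layer.set_weights performs before assigning any values."""
    # Check 1: one provided array per layer weight.
    # (The wording of this message is illustrative, not Keras's exact text.)
    if len(expected_shapes) != len(provided_shapes):
        raise ValueError(
            f"Layer {layer_name} expects {len(expected_shapes)} weights, "
            f"but {len(provided_shapes)} were provided."
        )
    # Check 2: arrays are matched to variables by position, so shapes must agree.
    for expected, provided in zip(expected_shapes, provided_shapes):
        if tuple(expected) != tuple(provided):
            raise ValueError(
                f"Layer {layer_name} weight shape {tuple(expected)} is not "
                f"compatible with provided weight shape {tuple(provided)}."
            )


# Shapes taken from the traceback; the swapped order is an assumption.
current = [(1, 16, 10), (1, 10)]
saved = [(1, 10), (1, 16, 10)]
try:
    check_weight_shapes("functional", current, saved)
except ValueError as err:
    print(err)
```

The traceback is truncated at `...`, so it does not show what called `set_weights` right after epoch 55. One Keras component that does so is `keras.callbacks.EarlyStopping(restore_best_weights=True)`, which restores the weights it saved earlier with `model.get_weights()` when it stops training; whether that applies here depends on the training script, which is not shown.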

With the following model shape:

Model: "functional"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                  ┃ Output Shape              ┃         Param # ┃ Connected to               ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ input_layer (InputLayer)      │ (None, 252, 13)           │               0 │ -                          │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ embedding_layer               │ (None, 252, 1, 13)        │               0 │ input_layer[0][0]          │
│ (EmbeddingLayer)              │                           │                 │                            │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ past_observed_and_known       │ (None, 251, 1, 13)        │               0 │ embedding_layer[0][0]      │
│ (Lambda)                      │                           │                 │                            │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ vsn_past_features             │ (None, 251, 16)           │          13,002 │ past_observed_and_known[0… │
│ (VariableSelectionNetwork)    │                           │                 │                            │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ future_known (Lambda)         │ (None, 1, 1, 13)          │               0 │ embedding_layer[0][0]      │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ encoder (RecurrentLayer)      │ [(None, 251, 16), (None,  │           4,400 │ vsn_past_features[0][0]    │
│                               │ 16), (None, 16), (None,   │                 │                            │
│                               │ 16)]                      │                 │                            │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ vsn_future_features           │ (None, 1, 16)             │          13,002 │ future_known[0][0]         │
│ (VariableSelectionNetwork)    │                           │                 │                            │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ decoder (RecurrentLayer)      │ (None, 1, 16)             │           4,400 │ vsn_future_features[0][0], │
│                               │                           │                 │ encoder[0][1],             │
│                               │                           │                 │ encoder[0][2],             │
│                               │                           │                 │ encoder[0][3]              │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ concatenate (Concatenate)     │ (None, 252, 16)           │               0 │ encoder[0][0],             │
│                               │                           │                 │ decoder[0][0]              │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ gate_28 (Gate)                │ (None, 252, 16)           │               0 │ concatenate[0][0]          │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ concatenate_1 (Concatenate)   │ (None, 252, 16)           │               0 │ vsn_past_features[0][0],   │
│                               │                           │                 │ vsn_future_features[0][0]  │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ add_and_norm_28 (AddAndNorm)  │ (None, 252, 16)           │              32 │ gate_28[0][0],             │
│                               │                           │                 │ concatenate_1[0][0]        │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ grn_28 (GRN)                  │ (None, 252, 16)           │               0 │ add_and_norm_28[0][0]      │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ multi_head_attention          │ (None, 252, 16)           │           4,304 │ grn_28[0][0],              │
│ (MultiHeadAttention)          │                           │                 │ grn_28[0][0], grn_28[0][0] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ flatten (Flatten)             │ (None, 4032)              │               0 │ multi_head_attention[0][0] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ dense_158 (Dense)             │ (None, 1)                 │           4,033 │ flatten[0][0]              │
└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘
 Total params: 43,173 (168.64 KB)
 Trainable params: 42,853 (167.39 KB)
 Non-trainable params: 320 (1.25 KB)
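The shapes and parameter counts in this summary are internally consistent, which a few lines of arithmetic confirm: 251 past steps plus 1 future step give the 252-step sequence, flattening 252 × 16 gives 4032 features, and the final Dense layer has 4032 weights plus 1 bias. Every number below is read off the summary; the variable names are only labels for this check. Note that 16, the hidden width here, is also the middle dimension of the (1, 16, 10) weight in the error, although the summary alone does not show which sublayer owns that weight.

```python
# Figures copied from the model summary in the issue.
PAST_STEPS = 251    # past_observed_and_known: (None, 251, 1, 13)
FUTURE_STEPS = 1    # future_known: (None, 1, 1, 13)
HIDDEN = 16         # width of the VSN / encoder / decoder outputs

seq_len = PAST_STEPS + FUTURE_STEPS   # concatenate: (None, 252, 16)
flat = seq_len * HIDDEN               # flatten: (None, 4032)
dense_params = flat * 1 + 1           # dense_158: 4032 weights + 1 bias, 1 output unit

# Layers with a non-zero "Param #" entry in the summary.
layer_params = {
    "vsn_past_features": 13_002,
    "encoder": 4_400,
    "vsn_future_features": 13_002,
    "decoder": 4_400,
    "add_and_norm_28": 32,
    "multi_head_attention": 4_304,
    "dense_158": dense_params,
}
total = sum(layer_params.values())
trainable = total - 320               # 320 non-trainable params per the summary

print(seq_len, flat, dense_params, total)  # 252 4032 4033 43173
print(trainable)                           # 42853
```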
remigenet commented 1 month ago

I am not really sure what the cause is, but I have a small guess: were you running in a distributed context or with multiple processes? I fixed a serialization issue in TKAN that was related to how the backend affected the initializers in keras_efficient_kan (it worked with the jax backend but not with torch or tensorflow). Try installing the latest tkan version (0.4.3) and run again; it may be as simple as that. If it still doesn't work, could you share your Python dependency setup and a reproducible example (using just a random data generator) so I can find out what's going on?
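A minimal sketch of the two things requested, assuming the input and target shapes from the summary above ((None, 252, 13) in, (None, 1) out). `environment_report` and `random_dataset` are hypothetical helper names, and the package list is a guess at what is relevant; any name that isn't installed (or is spelled differently on PyPI) just reports "not installed". It stops short of building the TKAT model, since the constructor arguments used in the issue are not shown.

```python
import os
import sys
from importlib import metadata

import numpy as np


def environment_report(packages=("tkan", "tkat", "keras", "keras_efficient_kan",
                                  "tensorflow", "torch", "jax")):
    """Hypothetical helper: collect Python, backend and package versions."""
    report = {
        "python": sys.version.split()[0],
        # Keras 3 picks its backend from this variable (or ~/.keras/keras.json).
        "KERAS_BACKEND": os.environ.get("KERAS_BACKEND", "<unset>"),
    }
    for name in packages:
        try:
            report[name] = metadata.version(name)  # e.g. tkan should be >= 0.4.3
        except metadata.PackageNotFoundError:
            report[name] = "not installed"
    return report


def random_dataset(n_samples, seq_len=252, n_features=13, seed=0):
    """Hypothetical helper: seeded random inputs/targets matching the summary."""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(n_samples, seq_len, n_features)).astype("float32")
    y = rng.normal(size=(n_samples, 1)).astype("float32")
    return x, y


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
    x, y = random_dataset(64)
    print(x.shape, y.shape)
```

Fixing the seed keeps the random data identical between runs, so a failure seen on one machine can be replayed on another with the same arrays.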