zhangxjohn / Reversible-Instance-Normalization

Implementation of RevIN is based on TF2.Keras and PyTorch.
Apache License 2.0

Doesn't work in TensorFlow #1

Open sagagk opened 1 year ago

sagagk commented 1 year ago

I use your layer as follows:


from RevIN import RevIN
from tensorflow.keras.layers import Input, LSTM, Dense
from tensorflow.keras.models import Model

%load_ext autoreload
%autoreload 2

revin_layer = RevIN(2)

x = Input(shape=(12, 2))
model = revin_layer(x, mode="norm")  # normalize the input window

model2 = LSTM(32, return_sequences=False)(model)
output_layer = Dense(2)(model2)
output_layer1 = revin_layer(output_layer, mode="denorm")  # denormalize the prediction
model1 = Model(inputs=x, outputs=output_layer1)

model1.summary()

But I get an error. It appears to happen in the backpropagation phase. The output:

Epoch 1/1000
Tensor("model/rev_in/StopGradient_1:0", shape=(None, 1, 2), dtype=float32)
Tensor("model/rev_in/StopGradient_1:0", shape=(None, 1, 2), dtype=float32)
Tensor("model/rev_in/StopGradient_1:0", shape=(None, 1, 2), dtype=float32)
Tensor("model/rev_in/StopGradient_1:0", shape=(None, 1, 2), dtype=float32)
Tensor("StopGradient_1:0", shape=(None, 1, 2), dtype=float32)

---------------------------------------------------------------------------

InaccessibleTensorError                   Traceback (most recent call last)

[<ipython-input-64-d0481f2f6b0e>](https://localhost:8080/#) in <cell line: 1>()
----> 1 history = model1.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=1000, callbacks=[cp,es], verbose=2)

1 frames

[/usr/lib/python3.10/contextlib.py](https://localhost:8080/#) in __exit__(self, typ, value, traceback)
    140         if typ is None:
    141             try:
--> 142                 next(self.gen)
    143             except StopIteration:
    144                 return False

InaccessibleTensorError: <tf.Tensor 'StopGradient_1:0' shape=(None, 1, 2) dtype=float32> is out of scope and cannot be used here. Use return values, explicit Python locals or TensorFlow collections to access it.
Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

<tf.Tensor 'StopGradient_1:0' shape=(None, 1, 2) dtype=float32> was defined here:
    File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
      return _run_code(code, main_globals, None,
(.........)
      return wrapped_call(*new_args, **new_kwargs)
    File "/usr/local/lib/python3.10/dist-packages/keras/saving/legacy/saved_model/save_impl.py", line 698, in call_and_return_conditional_losses
      call_output = layer_call(*args, **kwargs)
    File "/content/RevIN.py", line 34, in call
      self._get_statistics(inputs)
    File "/content/RevIN.py", line 45, in _get_statistics
      self.stdev = K.stop_gradient(K.sqrt(K.var(x, axis=dim2reduce, keepdims=True) + self.eps))
    File "/usr/local/lib/python3.10/dist-packages/keras/backend.py", line 4716, in stop_gradient
      return tf.stop_gradient(variables)

The tensor <tf.Tensor 'StopGradient_1:0' shape=(None, 1, 2) dtype=float32> cannot be accessed from FuncGraph(name=model_layer_call_and_return_conditional_losses, id=133472755512848), because it was defined in FuncGraph(name=rev_in_layer_call_and_return_conditional_losses, id=133472753537696), which is out of scope.

I see that your simple demo works, but when I train a neural network, this "out of scope" error appears.
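
For reference, here is a minimal self-contained sketch of my own (not the repository's RevIN.py) of the pattern the traceback points at: tensors created with K.stop_gradient during the "norm" call are stored on the layer and read back during the "denorm" call, and per the trace they can end up in a different FuncGraph (the failure surfaces while Keras retraces the layer, e.g. through keras/saving/legacy/saved_model/save_impl.py). I am not certain this toy version fails on every TensorFlow release, but it mirrors the structure of my model above:

    import numpy as np
    from tensorflow.keras import backend as K
    from tensorflow.keras.layers import Layer, Input, LSTM, Dense
    from tensorflow.keras.models import Model

    class NormDenorm(Layer):
        """Toy layer: statistics computed in one call are reused in a later call."""
        def call(self, inputs, mode="norm"):
            if mode == "norm":
                # Graph tensors kept on the layer instance across calls.
                self.mean = K.stop_gradient(K.mean(inputs, axis=1, keepdims=True))
                self.stdev = K.stop_gradient(
                    K.sqrt(K.var(inputs, axis=1, keepdims=True) + 1e-5))
                return (inputs - self.mean) / self.stdev
            # "denorm": reuses the tensors stored during the "norm" call.
            return inputs * self.stdev[:, 0, :] + self.mean[:, 0, :]

    layer = NormDenorm()
    x = Input(shape=(12, 2))
    h = layer(x, mode="norm")
    h = LSTM(32, return_sequences=False)(h)
    h = Dense(2)(h)
    out = layer(h, mode="denorm")
    model = Model(inputs=x, outputs=out)
    model.compile(optimizer="adam", loss="mse")
    model.fit(np.random.randn(16, 12, 2), np.random.randn(16, 2), epochs=1, verbose=0)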

zhangxjohn commented 1 year ago

Here is my test code and procedure:

    import numpy as np
    from tensorflow.keras.layers import Input, LSTM, Dense
    from tensorflow.keras.models import Model

    from RevIN import RevIN

    revin_layer = RevIN()

    x = Input(shape=(12, 2))
    model = revin_layer(x, mode="norm")

    model2 = LSTM(32, return_sequences=False)(model)
    output_layer = Dense(2)(model2)
    output_layer1 = revin_layer(output_layer, mode="denorm")
    model1 = Model(inputs=x, outputs=output_layer1)
    model1.summary()

    model1.compile(optimizer='Adam', loss='mse')

    x = np.random.randn(16, 12, 2)
    y = np.random.randn(16, 2)
    model1.fit(x=x, y=y, epochs=10, batch_size=1)

Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 12, 2)]      0                                            
__________________________________________________________________________________________________
rev_in (RevIN)                  multiple             4           input_1[0][0]                    
                                                                 dense[0][0]                      
__________________________________________________________________________________________________
lstm (LSTM)                     (None, 32)           4480        rev_in[0][0]                     
__________________________________________________________________________________________________
dense (Dense)                   (None, 2)            66          lstm[0][0]                       
==================================================================================================
Total params: 4,550
Trainable params: 4,550
Non-trainable params: 0
__________________________________________________________________________________________________
Epoch 1/10
16/16 [==============================] - 0s 2ms/step - loss: 0.9344
Epoch 2/10
16/16 [==============================] - 0s 2ms/step - loss: 0.9181
Epoch 3/10
16/16 [==============================] - 0s 2ms/step - loss: 0.9076
Epoch 4/10
16/16 [==============================] - 0s 2ms/step - loss: 0.9027
Epoch 5/10
16/16 [==============================] - 0s 2ms/step - loss: 0.8949
Epoch 6/10
16/16 [==============================] - 0s 2ms/step - loss: 0.8884
Epoch 7/10
16/16 [==============================] - 0s 2ms/step - loss: 0.8836
Epoch 8/10
16/16 [==============================] - 0s 2ms/step - loss: 0.8786
Epoch 9/10
16/16 [==============================] - 0s 2ms/step - loss: 0.8713
Epoch 10/10
16/16 [==============================] - 0s 3ms/step - loss: 0.8663

These are the relevant packages in my environment:

tensorflow==2.3.0
numpy==1.19.5

Maybe you can check your environment; that is my current suggestion.
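
If it helps, a quick way to confirm which versions the notebook session is actually running (the versions above are simply my own setup, not a hard requirement):

    import numpy as np
    import tensorflow as tf

    # Print the versions in the running session; the log above was produced
    # with tensorflow 2.3.0 and numpy 1.19.5.
    print(tf.__version__)
    print(np.__version__)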