tensorflow / addons

Useful extra functionality for TensorFlow 2.x maintained by SIG-addons
Apache License 2.0

ValueError: Tensor conversion requested dtype float64 for Tensor with dtype float32 #2792

Closed muriloasouza closed 1 year ago

muriloasouza commented 1 year ago

I'm trying to use Keras Tuner to optimize my ESN model's hyperparameters:

```python
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Dropout, InputLayer
from tensorflow_addons.layers import ESN

# timesteps and input_dim are defined elsewhere in the script (not shown here)

def build_model(hp):
    model = Sequential(name='ESN_NE')
    model.add(InputLayer((timesteps, input_dim), name='input_layer'))
    model.add(ESN(units=hp.Int('units', min_value=100, max_value=400, step=50),
                  connectivity=hp.Float('connectivity', min_value=0.1, max_value=0.4, step=0.1),
                  spectral_radius=hp.Float('spectral_radius', min_value=0.5, max_value=0.9, step=0.1),
                  activation=hp.Choice('activation_esn', ['relu', 'tanh', 'sigmoid']),
                  name='esn_layer'))
    if hp.Boolean("dropout"):
        model.add(Dropout(rate=hp.Float('rate', min_value=0.2, max_value=0.5, step=0.1)))

    for k in range(hp.Int("num_layers", 1, 2)):
        model.add(Dense(units=hp.Int(f'units_{k}', min_value=24, max_value=120, step=24),
                        activation=hp.Choice('activation_dense', ['relu', 'tanh', 'sigmoid']),
                        name=f'{k}_dense'))
    model.add(Dense(units=24,
                    activation=hp.Choice('activation_out', ['relu', 'tanh', 'sigmoid']),
                    name='output_layer'))
    model.compile(optimizer='adam',
                  loss='mean_squared_error')
    return model
```

Getting the following error:

```
Traceback (most recent call last):
  File "C:\Users\Muril\PycharmProjects\tf-gpu\venv\lib\site-packages\keras_tuner\engine\tuner.py", line 158, in _try_build
    model = self._build_hypermodel(hp)
  File "C:\Users\Muril\PycharmProjects\tf-gpu\venv\lib\site-packages\keras_tuner\engine\tuner.py", line 146, in _build_hypermodel
    model = self.hypermodel.build(hp)
  File "<ipython-input-11-40dc25714941>", line 5, in build_model
    model.add(ESN(units=hp.Int('units', min_value=100, max_value=400, step=50),
  File "C:\Users\Muril\PycharmProjects\tf-gpu\venv\lib\site-packages\tensorflow\python\training\tracking\base.py", line 629, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "C:\Users\Muril\PycharmProjects\tf-gpu\venv\lib\site-packages\keras\utils\traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "C:\Users\Muril\PycharmProjects\tf-gpu\venv\lib\site-packages\tensorflow_addons\rnn\esn_cell.py", line 159, in build
    self.recurrent_kernel = self.add_weight(
  File "C:\Users\Muril\PycharmProjects\tf-gpu\venv\lib\site-packages\tensorflow_addons\rnn\esn_cell.py", line 151, in _esn_recurrent_initializer
    scaling_factor = tf.math.divide_no_nan(
ValueError: Tensor conversion requested dtype float64 for Tensor with dtype float32: <tf.Tensor: shape=(), dtype=float32, numpy=0.4704204>
Invalid model 0/5
```

The tuner retries the build 5 more times, but every attempt fails. In model.add(ESN(units=hp.Int('units', min_value=100, max_value=400, step=50), ... the units are an Int, so why does the error message mention float64? What am I missing here?
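The units argument is probably not involved: the failing line in the traceback is the tf.math.divide_no_nan(...) call that rescales the recurrent kernel by spectral_radius. A likely explanation, assuming hp.Float hands back a NumPy float64 rather than a built-in Python float (consistent with "requested dtype float64" in the message), is TensorFlow's dtype inference: a NumPy float64 scalar becomes a float64 tensor, while a built-in float defaults to float32, the dtype of the ESN weights. In the TF version used here, divide_no_nan converts its first argument and then forces the second one to the same dtype, so a float64 spectral_radius would make the float32 eigenvalue tensor (0.4704204 above) fail to convert. The snippet only illustrates the conversion rule; 0.7 is an arbitrary stand-in value.

```python
import numpy as np
import tensorflow as tf

# Arbitrary stand-in for a sampled spectral_radius.
# Assumption: hp.Float returns a NumPy float64 here.
numpy_value = np.float64(0.7)
builtin_value = 0.7

# NumPy scalars keep their dtype when converted to a tensor...
numpy_tensor = tf.convert_to_tensor(numpy_value)
# ...while built-in Python floats default to float32.
builtin_tensor = tf.convert_to_tensor(builtin_value)

print(numpy_tensor.dtype.name)    # float64
print(builtin_tensor.dtype.name)  # float32
```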

bhack commented 1 year ago

Do you have a very minimal but standalone runnable gist to reproduce this?

muriloasouza commented 1 year ago

@bhack here:

https://gist.github.com/muriloasouza/6b928d762fbbd2f748c68bc99eaf8382

What is the best way to upload the CSV file? It is about 4 MB.

bhack commented 1 year ago

Aren't just a few lines of the CSV enough? In that case you could share them in a GitHub gist.

muriloasouza commented 1 year ago
DataHora,Carga,Seno,Cosseno,Segunda,Terça,Quarta,Quinta,Sexta,Sábado,Domingo
02/01/2017 00:00,0.7403215090534004,0.49999999999999994,0.8660254037844387,1,0,0,0,0,0,0
02/01/2017 01:00,0.7218042830437358,0.49999999999999994,0.8660254037844387,1,0,0,0,0,0,0
02/01/2017 02:00,0.6953087864546649,0.49999999999999994,0.8660254037844387,1,0,0,0,0,0,0
02/01/2017 03:00,0.6762163581271778,0.49999999999999994,0.8660254037844387,1,0,0,0,0,0,0
02/01/2017 04:00,0.6609037125126326,0.49999999999999994,0.8660254037844387,1,0,0,0,0,0,0
02/01/2017 05:00,0.6502756740811575,0.49999999999999994,0.8660254037844387,1,0,0,0,0,0,0
02/01/2017 06:00,0.6460838636778332,0.49999999999999994,0.8660254037844387,1,0,0,0,0,0,0
02/01/2017 07:00,0.6029047724573929,0.49999999999999994,0.8660254037844387,1,0,0,0,0,0,0
02/01/2017 08:00,0.606485961773619,0.49999999999999994,0.8660254037844387,1,0,0,0,0,0,0
02/01/2017 09:00,0.6708915421118841,0.49999999999999994,0.8660254037844387,1,0,0,0,0,0,0
02/01/2017 10:00,0.748644487208883,0.49999999999999994,0.8660254037844387,1,0,0,0,0,0,0
02/01/2017 11:00,0.7971625205708657,0.49999999999999994,0.8660254037844387,1,0,0,0,0,0,0
02/01/2017 12:00,0.8115746679590072,0.49999999999999994,0.8660254037844387,1,0,0,0,0,0,0
02/01/2017 13:00,0.8123554232830152,0.49999999999999994,0.8660254037844387,1,0,0,0,0,0,0
02/01/2017 14:00,0.800616593560609,0.49999999999999994,0.8660254037844387,1,0,0,0,0,0,0
02/01/2017 15:00,0.8220881549805562,0.49999999999999994,0.8660254037844387,1,0,0,0,0,0,0
02/01/2017 16:00,0.8370240411959524,0.49999999999999994,0.8660254037844387,1,0,0,0,0,0,0
02/01/2017 17:00,0.8271795419322734,0.49999999999999994,0.8660254037844387,1,0,0,0,0,0,0
02/01/2017 18:00,0.8007780188808035,0.49999999999999994,0.8660254037844387,1,0,0,0,0,0,0
02/01/2017 19:00,0.732694002791285,0.49999999999999994,0.8660254037844387,1,0,0,0,0,0,0
02/01/2017 20:00,0.7599427172225689,0.49999999999999994,0.8660254037844387,1,0,0,0,0,0,0
02/01/2017 21:00,0.7864011585936508,0.49999999999999994,0.8660254037844387,1,0,0,0,0,0,0
02/01/2017 22:00,0.7772003667766858,0.49999999999999994,0.8660254037844387,1,0,0,0,0,0,0
02/01/2017 23:00,0.8270022787980956,0.49999999999999994,0.8660254037844387,1,0,0,0,0,0,0
03/01/2017 00:00,0.8294164356549703,0.49999999999999994,0.8660254037844387,0,1,0,0,0,0,0
03/01/2017 01:00,0.8078237394107332,0.49999999999999994,0.8660254037844387,0,1,0,0,0,0,0
03/01/2017 02:00,0.777077426215885,0.49999999999999994,0.8660254037844387,0,1,0,0,0,0,0
03/01/2017 03:00,0.7541306533053501,0.49999999999999994,0.8660254037844387,0,1,0,0,0,0,0
03/01/2017 04:00,0.7392694041990816,0.49999999999999994,0.8660254037844387,0,1,0,0,0,0,0
03/01/2017 05:00,0.7245129254299703,0.49999999999999994,0.8660254037844387,0,1,0,0,0,0,0
03/01/2017 06:00,0.7136136127072056,0.49999999999999994,0.8660254037844387,0,1,0,0,0,0,0
03/01/2017 07:00,0.6661957619341512,0.49999999999999994,0.8660254037844387,0,1,0,0,0,0,0
03/01/2017 08:00,0.6631672408488687,0.49999999999999994,0.8660254037844387,0,1,0,0,0,0,0
03/01/2017 09:00,0.7256720201725747,0.49999999999999994,0.8660254037844387,0,1,0,0,0,0,0
03/01/2017 10:00,0.7960327314270567,0.49999999999999994,0.8660254037844387,0,1,0,0,0,0,0
03/01/2017 11:00,0.8295777104971223,0.49999999999999994,0.8660254037844387,0,1,0,0,0,0,0
03/01/2017 12:00,0.8405440988073739,0.49999999999999994,0.8660254037844387,0,1,0,0,0,0,0
03/01/2017 13:00,0.8388564499401516,0.49999999999999994,0.8660254037844387,0,1,0,0,0,0,0
03/01/2017 14:00,0.8255494509163441,0.49999999999999994,0.8660254037844387,0,1,0,0,0,0,0
03/01/2017 15:00,0.8527288325895022,0.49999999999999994,0.8660254037844387,0,1,0,0,0,0,0
03/01/2017 16:00,0.8677148091833274,0.49999999999999994,0.8660254037844387,0,1,0,0,0,0,0
03/01/2017 17:00,0.8594355637089743,0.49999999999999994,0.8660254037844387,0,1,0,0,0,0,0
03/01/2017 18:00,0.8302994031894114,0.49999999999999994,0.8660254037844387,0,1,0,0,0,0,0
03/01/2017 19:00,0.7498500044276283,0.49999999999999994,0.8660254037844387,0,1,0,0,0,0,0
03/01/2017 20:00,0.7796490959638468,0.49999999999999994,0.8660254037844387,0,1,0,0,0,0,0
03/01/2017 21:00,0.8022335929868324,0.49999999999999994,0.8660254037844387,0,1,0,0,0,0,0
03/01/2017 22:00,0.7983398231566243,0.49999999999999994,0.8660254037844387,0,1,0,0,0,0,0
03/01/2017 23:00,0.8503465394581473,0.49999999999999994,0.8660254037844387,0,1,0,0,0,0,0
04/01/2017 00:00,0.8490615510231562,0.49999999999999994,0.8660254037844387,0,0,1,0,0,0,0
04/01/2017 01:00,0.8258883462780251,0.49999999999999994,0.8660254037844387,0,0,1,0,0,0,0
04/01/2017 02:00,0.786465751293435,0.49999999999999994,0.8660254037844387,0,0,1,0,0,0,0
04/01/2017 03:00,0.7569455708091678,0.49999999999999994,0.8660254037844387,0,0,1,0,0,0,0
04/01/2017 04:00,0.7408715815381094,0.49999999999999994,0.8660254037844387,0,0,1,0,0,0,0
04/01/2017 05:00,0.7327962150017181,0.49999999999999994,0.8660254037844387,0,0,1,0,0,0,0
04/01/2017 06:00,0.722582781197105,0.49999999999999994,0.8660254037844387,0,0,1,0,0,0,0
04/01/2017 07:00,0.6765964280432641,0.49999999999999994,0.8660254037844387,0,0,1,0,0,0,0
04/01/2017 08:00,0.6688503450835832,0.49999999999999994,0.8660254037844387,0,0,1,0,0,0,0
04/01/2017 09:00,0.7336440835327239,0.49999999999999994,0.8660254037844387,0,0,1,0,0,0,0
04/01/2017 10:00,0.8070142427805892,0.49999999999999994,0.8660254037844387,0,0,1,0,0,0,0
04/01/2017 11:00,0.8380112900139125,0.49999999999999994,0.8660254037844387,0,0,1,0,0,0,0
04/01/2017 12:00,0.8582326422364637,0.49999999999999994,0.8660254037844387,0,0,1,0,0,0,0
04/01/2017 13:00,0.8558907900790563,0.49999999999999994,0.8660254037844387,0,0,1,0,0,0,0
04/01/2017 14:00,0.8294126360843945,0.49999999999999994,0.8660254037844387,0,0,1,0,0,0,0
04/01/2017 15:00,0.861100753728339,0.49999999999999994,0.8660254037844387,0,0,1,0,0,0,0
04/01/2017 16:00,0.884894417062603,0.49999999999999994,0.8660254037844387,0,0,1,0,0,0,0
04/01/2017 17:00,0.8808726656571672,0.49999999999999994,0.8660254037844387,0,0,1,0,0,0,0
04/01/2017 18:00,0.8418242155157362,0.49999999999999994,0.8660254037844387,0,0,1,0,0,0,0
04/01/2017 19:00,0.7639605185793866,0.49999999999999994,0.8660254037844387,0,0,1,0,0,0,0
04/01/2017 20:00,0.7956061637958083,0.49999999999999994,0.8660254037844387,0,0,1,0,0,0,0
04/01/2017 21:00,0.82005824380545,0.49999999999999994,0.8660254037844387,0,0,1,0,0,0,0
04/01/2017 22:00,0.8144962367755147,0.49999999999999994,0.8660254037844387,0,0,1,0,0,0,0
04/01/2017 23:00,0.8615975946054796,0.49999999999999994,0.8660254037844387,0,0,1,0,0,0,0

@bhack Here are some lines of it.

bhack commented 1 year ago

Your code is not running: name 'tuner' is not defined.

Please update and test your gist, and don't import keras directly; use Keras from TF instead (import tensorflow.keras).

muriloasouza commented 1 year ago

@bhack Sorry about that, there was an indentation error when I copied and pasted into the gist. I also replaced all the direct keras imports.

bhack commented 1 year ago

https://github.com/tensorflow/addons/blob/4fda3d4a9c298096edd3da36244a3de8355c88e5/tensorflow_addons/rnn/esn_cell.py#L146-L152

Can you try wrapping self.spectral_radius in tf.cast(self.spectral_radius, dtype) on these lines? If it works, can you submit a PR?

muriloasouza commented 1 year ago

@bhack Not sure if this is what I was supposed to do, but in the esn_cell.py file in my tensorflow_addons/rnn folder I changed both occurrences of self.spectral_radius, at lines 146 and 152, to tf.cast(self.spectral_radius, dtype). It works: the tuner is running now.
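For reference, here is a self-contained sketch of what the ESN recurrent initializer does and where the cast goes. It is not the tensorflow_addons source: the function name scaled_recurrent_weights, the uniform initial weights, and the stateless seeds are assumptions made for illustration. It draws a sparse random matrix, then rescales it so its largest absolute eigenvalue equals spectral_radius, casting the hyperparameter to the weight dtype first so that a float64 value coming from the tuner cannot clash with float32 weights.

```python
import numpy as np
import tensorflow as tf


def scaled_recurrent_weights(units, spectral_radius, connectivity,
                             dtype=tf.float32, seed=0):
    """Hypothetical ESN-style recurrent initializer with the tf.cast fix applied."""
    shape = (units, units)
    # Dense random weights in [-1, 1) (assumed initializer, for illustration only).
    weights = tf.random.stateless_uniform(
        shape, seed=[seed, 0], minval=-1.0, maxval=1.0, dtype=dtype)
    # Keep each connection with probability `connectivity`.
    keep = tf.random.stateless_uniform(shape, seed=[seed, 1], dtype=dtype) <= tf.cast(
        connectivity, dtype)
    weights = weights * tf.cast(keep, dtype)
    # Largest absolute eigenvalue (eigvals returns complex values, abs makes them real).
    max_abs_eig = tf.reduce_max(tf.abs(tf.linalg.eigvals(weights)))
    # The fix: cast the hyperparameter to the weight dtype before dividing.
    scaling = tf.math.divide_no_nan(tf.cast(spectral_radius, dtype), max_abs_eig)
    return weights * scaling


# Feed float64 hyperparameters, as the tuner might, and check the resulting spectral radius.
w = scaled_recurrent_weights(50, np.float64(0.7), np.float64(0.3))
radius = float(np.max(np.abs(np.linalg.eigvals(w.numpy()))))
print(w.dtype.name, round(radius, 3))
```

Because scaling a matrix by a constant scales all of its eigenvalues by the same constant, the rescaled matrix ends up with spectral radius equal to spectral_radius (up to float32 rounding).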

But I'm sorry, I have no idea what a PR is or how to make one (I'm kind of new to GitHub).

bhack commented 1 year ago

> But I'm sorry, I have no idea what a PR is or how to make one (I'm kind of new to GitHub).

Check this https://docs.github.com/en/desktop/contributing-and-collaborating-using-github-desktop/working-with-your-remote-repository-on-github-or-github-enterprise/creating-an-issue-or-pull-request#creating-a-pull-request

bhack commented 1 year ago

Just a reminder: with Keras Tuner, float hyperparameters are doubles (float64). See https://github.com/keras-team/keras-tuner/commit/85521a37dfc2d3e36d1f9f99de93d590eb4e31b1
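Given that, a workaround that avoids editing the installed package may be to convert the sampled value to a built-in Python float before passing it to the layer, e.g. spectral_radius=float(hp.Float('spectral_radius', min_value=0.5, max_value=0.9, step=0.1)). TensorFlow converts a plain Python float to float32, which matches the ESN weights. This was not tried in the thread; the snippet below only demonstrates the conversion on a stand-in value, with a hypothetical float32 tensor in place of the largest eigenvalue magnitude.

```python
import numpy as np
import tensorflow as tf

# Stand-in for a tuner sample (assumption: hp.Float returns np.float64).
sampled = np.float64(0.7)
# Hypothetical float32 value playing the role of max |eigenvalue| of the recurrent weights.
max_abs_eig = tf.constant(0.4704204, dtype=tf.float32)

# float() turns the NumPy scalar into a built-in float, which TF converts as float32.
scaling = tf.math.divide_no_nan(float(sampled), max_abs_eig)
print(scaling.dtype.name)  # float32
```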