jiangsutx / SRN-Deblur

Repository for Scale-recurrent Network for Deep Image Deblurring
http://www.xtao.website/projects/srndeblur/srndeblur_cvpr18.pdf
MIT License

lstm model #78

Open noranali opened 1 year ago

noranali commented 1 year ago

@jiangsutx, thank you for your contribution. Please help me: when I try to train the LSTM model, I get this error:

/content/drive/My Drive/SRN-Deblur-master/models/model.py in generator(self, inputs, reuse, scope)
     79     with tf.compat.v1.variable_scope('LSTM'):
     80         cell = BasicConvLSTMCell([h / 4, w / 4], [3, 3], 128)
---> 81         rnn_state =cell.zero_state(batch_size=self.batch_size,dtype=tf.float32)
     82 
     83     x_unwrap = []

/content/drive/My Drive/SRN-Deblur-master/util/BasicConvLSTMCell.py in zero_state(self, batch_size, dtype)
     19     shape = self.shape
     20     num_features = self.num_features
---> 21     zeros = tf.compat.v1.zeros([batch_size, shape[0], shape[1], num_features * 2])
     22     return zeros
     23 

/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs)
    151     except Exception as e:
    152       filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153       raise e.with_traceback(filtered_tb) from None
    154     finally:
    155       del filtered_tb

/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/tensor_util.py in _AssertCompatible(values, dtype)
    330       raise TypeError("Expected any non-tensor type, but got a tensor instead.")
    331     else:
--> 332       raise TypeError(f"Expected {dtype.name}, but got {mismatch} of type "
    333                       f"'{type(mismatch).__name__}'.")
    334 

TypeError: Expected int32, but got 64.0 of type 'float'.
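A likely cause, assuming the model code was written for Python 2 integer division: under Python 3, `/` always returns a float, so `h / 4` with a 256-pixel crop evaluates to `64.0`, and `tf.compat.v1.zeros` rejects a float entry in its shape list. That matches the `64.0` in the error. The sketch below reproduces only the shape arithmetic in plain Python, without TensorFlow; `lstm_state_shape` is a hypothetical helper, and the 256x256 crop and batch size of 16 are assumptions.

```python
# Minimal sketch (no TensorFlow needed) of the shape passed to tf.compat.v1.zeros
# in BasicConvLSTMCell.zero_state(). Assumption: h and w are plain Python ints
# taken from the static input shape, e.g. 256x256 training crops.

def lstm_state_shape(batch_size, h, w, num_features, divisor=4):
    """Hypothetical helper: shape of the ConvLSTM zero state.

    The state stacks the cell and hidden tensors along the channel axis,
    hence num_features * 2, as in the traceback above.
    """
    # Floor division keeps every dimension an int, which tf.zeros requires.
    return [batch_size, h // divisor, w // divisor, num_features * 2]


h, w = 256, 256
print(h / 4)   # 64.0: true division always yields a float in Python 3
print(h // 4)  # 64: floor division keeps an int
print(lstm_state_shape(batch_size=16, h=h, w=w, num_features=128))  # [16, 64, 64, 256]
```

If this is the cause, changing line 80 of `models/model.py` to `BasicConvLSTMCell([h // 4, w // 4], [3, 3], 128)` (or wrapping each value in `int(...)`) should make the zero-state shape integral. Other places that compute shapes with `/` may need the same change; that has not been checked here.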

noranali commented 1 year ago


Please help, @jiangsutx.