BrikerMan / Kashgari

Kashgari is a production-level NLP transfer-learning framework built on top of tf.keras for text labeling and text classification; it includes Word2Vec, BERT, and GPT-2 language embeddings.
http://kashgari.readthedocs.io/
Apache License 2.0
2.39k stars 441 forks

Could a trainable=False option be added to GPT2 Embedding? [Feature request] #292

Closed josenxx closed 4 years ago

josenxx commented 4 years ago

I want to use GPT2 Embedding, but since trainable cannot be set to False, training fails. The error is that GPU memory is exhausted; I switched the GPU from a K80 to a P6000, but it still fails. If trainable could be set to False, the error would probably go away, since BERTEmbedding runs fine on the K80. Below is the error output on the K80:

```
ResourceExhaustedError                    Traceback (most recent call last)
<ipython-input> in <module>
     13 model = BiLSTM_Model(embedding)
     14 start = time.time()
---> 15 model.fit(train_x, train_y, valid_x, valid_y, epochs=50)
     16 end = time.time()
     17 elapsed = end - start

/usr/local/lib/python3.6/dist-packages/kashgari/tasks/base_model.py in fit(self, x_train, y_train, x_validate, y_validate, batch_size, epochs, callbacks, fit_kwargs, shuffle)
    308                 validation_steps=validation_steps,
    309                 callbacks=callbacks,
--> 310                 **fit_kwargs)
    311
    312     def fit_without_generator(self,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   1431         shuffle=shuffle,
   1432         initial_epoch=initial_epoch,
-> 1433         steps_name='steps_per_epoch')
   1434
   1435   def evaluate_generator(self,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training_generator.py in model_iteration(model, data, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch, mode, batch_size, steps_name, **kwargs)
    262
    263   is_deferred = not model._is_compiled
--> 264   batch_outs = batch_function(*batch_data)
    265   if not isinstance(batch_outs, list):
    266     batch_outs = [batch_outs]

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in train_on_batch(self, x, y, sample_weight, class_weight, reset_metrics)
   1173     self._update_sample_weight_modes(sample_weights=sample_weights)
   1174     self._make_train_function()
-> 1175     outputs = self.train_function(ins)  # pylint: disable=not-callable
   1176
   1177     if reset_metrics:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py in __call__(self, inputs)
   3290
   3291     fetched = self._callable_fn(*array_vals,
-> 3292                                 run_metadata=self.run_metadata)
   3293     self._call_fetch_callbacks(fetched[-len(self._fetches):])
   3294     output_structure = nest.pack_sequence_as(

/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in __call__(self, *args, **kwargs)
   1456     ret = tf_session.TF_SessionRunCallable(self._session._session,
   1457                                            self._handle, args,
-> 1458                                            run_metadata_ptr)
   1459     if run_metadata:
   1460       proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

ResourceExhaustedError: 2 root error(s) found.
  (0) Resource exhausted: OOM when allocating tensor with shape[64,59,50257] and type float
      on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
      [[{{node Output/truediv}}]]
    Hint: If you want to see a list of allocated tensors when OOM happens, add
    report_tensor_allocations_upon_oom to RunOptions for current allocation info.
      [[metrics_2/acc/Identity/_3061]]
    Hint: If you want to see a list of allocated tensors when OOM happens, add
    report_tensor_allocations_upon_oom to RunOptions for current allocation info.
  (1) Resource exhausted: OOM when allocating tensor with shape[64,59,50257] and type float
      on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
      [[{{node Output/truediv}}]]
    Hint: If you want to see a list of allocated tensors when OOM happens, add
    report_tensor_allocations_upon_oom to RunOptions for current allocation info.
0 successful operations. 0 derived errors ignored.
```
Below is the error output on the P6000:

```
InternalError                             Traceback (most recent call last)
<ipython-input> in <module>
     13 model = BiLSTM_Model(embedding)
     14 start = time.time()
---> 15 model.fit(train_x, train_y, valid_x, valid_y, epochs=50)
     16 end = time.time()
     17 elapsed = end - start

(... identical call stack through kashgari/tasks/base_model.py, tensorflow's
training.py, training_generator.py, backend.py, and client/session.py ...)

InternalError: Failed to call ThenRnnBackward with model config:
  [rnn_mode, rnn_input_mode, rnn_direction_mode]: 2, 0, 0,
  [num_layers, input_size, num_units, dir_count, max_seq_length, batch_size]: [1, 50257, 128, 1, 87, 64]
  [[{{node Adam_1/gradients/layer_blstm/CudnnRNN_grad/CudnnRNNBackprop}}]]
```
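For reference, freezing an embedding in tf.keras only requires setting `trainable = False` on its layers before the task model is compiled. The sketch below is illustrative only: the `freeze_embedding_layers` helper and the toy model are not Kashgari API, just a minimal demonstration of the tf.keras mechanism this feature request asks to expose.

```python
# Minimal sketch: freeze every layer of an embedding sub-model so its
# weights are excluded from gradient computation (and from the optimizer's
# slot variables), which reduces training-time GPU memory.
# `freeze_embedding_layers` is a hypothetical helper, not part of Kashgari.
import tensorflow as tf


def freeze_embedding_layers(embed_model: tf.keras.Model) -> None:
    """Mark all layers of `embed_model` as non-trainable."""
    for layer in embed_model.layers:
        layer.trainable = False


# Toy demonstration with a plain Embedding layer standing in for GPT-2:
inputs = tf.keras.Input(shape=(10,), dtype="int32")
x = tf.keras.layers.Embedding(input_dim=100, output_dim=8)(inputs)
toy = tf.keras.Model(inputs, x)

freeze_embedding_layers(toy)
print(len(toy.trainable_weights))  # no weights will be updated during fit()
```

Note that freezing must happen before `compile()`; changing `trainable` afterwards requires recompiling the model for the change to take effect.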
stale[bot] commented 4 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.