amaiya / ktrain

ktrain is a Python library that makes deep learning and AI more accessible and easier to apply
Apache License 2.0
1.23k stars, 269 forks

XLNET model error when validating #484

Closed: Jain-Abhilash closed this issue 1 year ago

Jain-Abhilash commented 1 year ago

Hey, I'm getting the error below when I call the learner's validate function:

train sequence lengths:
        mean : 88
        95percentile : 170
        99percentile : 214
Is Multi-Label? False
preprocessing test...
language: en
test sequence lengths:
        mean : 88
        95percentile : 170
        99percentile : 214
   3220/Unknown - 93s 27ms/step
Traceback (most recent call last):
  File "xlnetretrain.py", line 31, in <module>
    learner.validate(class_names=t.get_classes())
  File "/usr/local/lib/python3.8/dist-packages/ktrain/core.py", line 168, in validate
    y_pred = self.predict(val_data=val)
  File "/usr/local/lib/python3.8/dist-packages/ktrain/text/learner.py", line 186, in predict
    preds = self.model.predict(self._prepare(val, train=False))
  File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/eager/execute.py", line 52, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.ResourceExhaustedError: Graph execution error:
Detected at node 'tfxl_net_for_sequence_classification_1/transformer/layer_._1/rel_attn/einsum_3/Einsum' defined at (most recent call last):
    File "xlnetretrain.py", line 31, in <module>
      learner.validate(class_names=t.get_classes())
    File "/usr/local/lib/python3.8/dist-packages/ktrain/core.py", line 168, in validate
      y_pred = self.predict(val_data=val)
    File "/usr/local/lib/python3.8/dist-packages/ktrain/text/learner.py", line 186, in predict
      preds = self.model.predict(self._prepare(val, train=False))
    File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 2350, in predict
      tmp_batch_outputs = self.predict_function(iterator)
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 2137, in predict_function
      return step_function(self, iterator)
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 2123, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 2111, in run_step
      outputs = model.predict_step(data)
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 2079, in predict_step
      return self(x, training=False)
    File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 561, in __call__
      return super().__call__(*args, **kwargs)
    File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1132, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.8/dist-packages/transformers/modeling_tf_utils.py", line 1399, in run_call_with_unpacked_inputs
      loss="passthrough",
    File "/usr/local/lib/python3.8/dist-packages/transformers/models/xlnet/modeling_tf_xlnet.py", line 1413, in call
      transformer_outputs = self.transformer(
    File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1132, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.8/dist-packages/transformers/modeling_tf_utils.py", line 1399, in run_call_with_unpacked_inputs
      loss="passthrough",
    File "/usr/local/lib/python3.8/dist-packages/transformers/models/xlnet/modeling_tf_xlnet.py", line 740, in call
      for i, layer_module in enumerate(self.layer):
    File "/usr/local/lib/python3.8/dist-packages/transformers/models/xlnet/modeling_tf_xlnet.py", line 747, in call
      outputs = layer_module(
    File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1132, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.8/dist-packages/transformers/models/xlnet/modeling_tf_xlnet.py", line 378, in call
      outputs = self.rel_attn(
    File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1132, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.8/dist-packages/transformers/models/xlnet/modeling_tf_xlnet.py", line 204, in call
      if g is not None:
    File "/usr/local/lib/python3.8/dist-packages/transformers/models/xlnet/modeling_tf_xlnet.py", line 302, in call
      k_head_r = tf.einsum("ibh,hnd->ibnd", r, self.r)
Node: 'tfxl_net_for_sequence_classification_1/transformer/layer_._1/rel_attn/einsum_3/Einsum'
OOM when allocating tensor with shape[512,1024] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
         [[{{node tfxl_net_for_sequence_classification_1/transformer/layer_._1/rel_attn/einsum_3/Einsum}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.
 [Op:__inference_predict_function_28628]

Training went well with no issues. I'm using A100 GPUs with the latest ktrain and TensorFlow. The issue occurs only with XLNet; when I ran the same code with DistilBERT, there were no issues at all. My code is:

model = t.get_classifier()
learner = ktrain.get_learner(model,train_data=trn, val_data=val, eval_batch_size=1, batch_size=64)
learner.model.load_weights('models_xlnet/weights-04.hdf5')

learner.validate(class_names=t.get_classes())
learner.view_top_losses(n=10, preproc=t)
predictor = ktrain.get_predictor(learner.model, preproc=t)
predictor.save('models_xlnet/predictor')

amaiya commented 1 year ago

These are out-of-memory (OOM) errors. They seem to occur when model.predict is invoked with an XLNet model on a single large set of examples, regardless of the batch size set. Since this happens only with XLNet and not with other models (e.g., BERT, RoBERTa), it seems to be an issue with either transformers or TensorFlow rather than ktrain.

In any case, the workaround is to batchify the dataset yourself and feed the batches to predict. Here is a self-contained example of the workaround (where STEP 3 is the actual workaround):

# STEP 1:  load text data
categories = ['alt.atheism', 'soc.religion.christian','comp.graphics', 'sci.med']
from sklearn.datasets import fetch_20newsgroups
train_b = fetch_20newsgroups(subset='train', categories=categories, shuffle=True)
test_b = fetch_20newsgroups(subset='test',categories=categories, shuffle=True)
(x_train, y_train) = (train_b.data, train_b.target)
(x_test, y_test) = (test_b.data, test_b.target)

# STEP 2:  build and train XLNet
import ktrain
from ktrain import text
MODEL_NAME = 'xlnet-base-cased'
t = text.Transformer(MODEL_NAME, maxlen=500, class_names=train_b.target_names)
trn = t.preprocess_train(x_train, y_train)
val = t.preprocess_test(x_test, y_test)
model = t.get_classifier()
learner = ktrain.get_learner(model, train_data=trn, val_data=None, batch_size=6)
learner.fit_onecycle(5e-5, 1)

# STEP 3:  make predictions
p = ktrain.get_predictor(learner.model, t)
from ktrain import utils as U
batches = U.batchify(x_test, 32)
preds = []
for batch in batches:
    preds.extend(p.predict(batch))

# STEP 4: ground truth
ground_truth = [train_b.target_names[y] for y in y_test]

# STEP 5: evaluate
from sklearn.metrics import classification_report
print(classification_report(ground_truth, preds))

# OUTPUT
#                        precision    recall  f1-score   support
#
#           alt.atheism       0.83      0.86      0.85       319
#         comp.graphics       0.98      0.96      0.97       389
#               sci.med       0.94      0.97      0.96       396
#soc.religion.christian       0.92      0.88      0.90       398
#
#              accuracy                           0.92      1502
#             macro avg       0.92      0.92      0.92      1502
#          weighted avg       0.92      0.92      0.92      1502

P.S. ktrain>=0.33.4 automatically suppresses the progress bars when predict is called in a for loop.
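
The loop in STEP 3 can also be wrapped in a small reusable helper. The following is only a sketch in plain Python: `chunks` and `predict_in_batches` are hypothetical names, not part of ktrain, and `fake_predict` is a stand-in for a real prediction function. It assumes the prediction function takes a list of texts and returns one prediction per text, as `p.predict` does in the example above.

```python
from typing import Callable, Iterator, List, Sequence


def chunks(items: Sequence, size: int) -> Iterator[Sequence]:
    """Yield consecutive slices of `items`, each with at most `size` elements."""
    if size < 1:
        raise ValueError("size must be a positive integer")
    for start in range(0, len(items), size):
        yield items[start:start + size]


def predict_in_batches(
    predict_fn: Callable[[List], List],
    texts: Sequence,
    batch_size: int = 32,
) -> List:
    """Call `predict_fn` on one small batch at a time and concatenate the results.

    Only one batch is handed to the model per call, so the model never
    receives the full dataset in a single predict invocation.
    """
    preds: List = []
    for batch in chunks(texts, batch_size):
        preds.extend(predict_fn(list(batch)))
    return preds


# Stand-in for a real predictor (hypothetical): returns the length of each text.
def fake_predict(batch: List[str]) -> List[int]:
    return [len(text) for text in batch]


docs = ["a", "bb", "ccc", "dddd", "eeeee"]
print(predict_in_batches(fake_predict, docs, batch_size=2))  # [1, 2, 3, 4, 5]
```

With the predictor from STEP 3, the call would presumably look like `preds = predict_in_batches(p.predict, x_test, batch_size=32)`. The order of predictions matches the order of the input texts, so they can be compared directly against the ground truth.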