
Issue when running training_loop.run(2000) - message StopIteration in next_batch(self) #1779

Open SisypheGeek opened 1 year ago

SisypheGeek commented 1 year ago

Hi, I am facing an issue and I was hoping someone could tell me how to solve it. I am trying to build a neural network for sentiment analysis, as shown in the Trax introduction, but using the NLTK Twitter dataset. When I run the training loop I get a StopIteration, and I am at a point where I don't know what else to try.

Apart from the data, nothing differs much from the Trax Quick Intro. Please see the code below. Everything works except the call to training_loop.run(2000).

Environment information

I am running the code in a Google Colab notebook.

Here is the code:

# Steps to reproduce:
import os
import numpy as np
import trax
from trax import layers as tl
from trax.supervised import training  # provides TrainTask, EvalTask and Loop
from nltk.corpus import twitter_samples
from sklearn.model_selection import train_test_split

def train_val_test_split():
  all_positive_tweets = twitter_samples.strings('positive_tweets.json')
  all_negative_tweets = twitter_samples.strings('negative_tweets.json')
  all_tweet = all_positive_tweets + all_negative_tweets
  label = np.append(np.ones((len(all_positive_tweets), 1)), np.zeros((len(all_negative_tweets), 1)), axis=0)
  x_train, x_test, y_train, y_test = train_test_split(all_tweet, label, test_size=0.1, random_state=42,
                                                        shuffle=True, stratify=label)
  x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.2, random_state=42,
                                                        shuffle=True, stratify=y_train)
  return x_train, x_val, y_train, y_val, x_test, y_test

x_train2, x_val2, y_train2, y_val2, x_test2, y_test2 = train_val_test_split()

# Plain single-pass generators over the (tweet, label) pairs.
def data_generator_training():
  for tweet2, label2 in zip(x_train2, y_train2):
    yield (tweet2.encode('utf-8'), label2[0])

def data_generator_eval():
  for tweet3, label3 in zip(x_val2, y_val2):
    yield (tweet3.encode('utf-8'), label3[0])

tweet_gen_training2 = data_generator_training()
tweet_gen_eval = data_generator_eval()

data_pipeline = trax.data.Serial(
    trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
    trax.data.Shuffle(),
    trax.data.FilterByLength(max_length=2048, length_keys=[0]),
    trax.data.BucketByLength(boundaries=[  32, 128, 512, 2048],
                             batch_sizes=[256,  64,  16,    4, 1],
                             length_keys=[0]),
    trax.data.AddLossWeights()
  )
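# data_pipeline maps a generator of (text, label) pairs to a generator of
# batched, padded (inputs, labels, loss-weights) tuples.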
train_batches_stream = data_pipeline(tweet_gen_training2)
eval_batches_stream = data_pipeline(tweet_gen_eval)

model = tl.Serial(
    tl.Embedding(vocab_size=8192, d_feature=256),
    tl.Mean(axis=1),  # Average on axis 1 (length of sentence).
    tl.Dense(2),      # Classify 2 classes.
    tl.LogSoftmax()   # Produce log-probabilities.
)

train_task = training.TrainTask(
    labeled_data=train_batches_stream,
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    optimizer=trax.optimizers.Adam(0.01),
    n_steps_per_checkpoint=500,
)

eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],
    n_eval_batches=20  # For less variance in eval numbers.
)
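
# The Loop construction itself was not shown in the snippet; below is a minimal
# sketch, assuming it follows the Trax Quick Intro (the output_dir path is an
# assumption):
output_dir = os.path.expanduser('~/output_dir/')
training_loop = training.Loop(model,
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)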

training_loop.run(2000)

# Error logs:
Step      1: Total number of trainable weights: 2304514
Step      1: Ran 1 train steps in 1.25 secs
Step      1: train WeightedCategoryCrossEntropy |  0.70142251
/usr/local/lib/python3.10/dist-packages/trax/supervised/training.py:1249: FutureWarning: GzipFile was opened for writing, but this will change in future Python releases.  Specify the mode argument for opening it for writing.
  with gzip_lib.GzipFile(fileobj=f, compresslevel=2) as gzipf:
---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
<ipython-input-93-d14f18d9bf82> in <cell line: 1>()
----> 1 training_loop.run(100)

2 frames
/usr/local/lib/python3.10/dist-packages/trax/supervised/training.py in next_batch(self)
   1188   def next_batch(self):
   1189     """Returns one batch of labeled data: a tuple of input(s) plus label."""
-> 1190     return next(self._labeled_data)
   1191 
   1192   @property

StopIteration: 

I just cannot see what I am missing that causes this issue. Any help would be much appreciated. Thank you.