rdevooght / sequence-based-recommendations

MIT License

RuntimeError: generator raised StopIteration #17

Status: Open. Opened by n0obcoder 4 years ago.

n0obcoder commented 4 years ago

I am getting the following error while training the model with the command:

python train.py -d path/to/dataset/
  File "D:\daftar\recomendation systems\sequence-based-recommendations-master\neural_networks\rnn_base.py", line 405, in _gen_mini_batch
    sequence, user_id = next(sequence_generator)
StopIteration

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "train.py", line 60, in <module>
    main()
  File "train.py", line 57, in main
    validation_metrics=args.metrics.split(','))
  File "D:\daftar\recomendation systems\sequence-based-recommendations-master\neural_networks\rnn_base.py", line 324, in train
    metrics = self._compute_validation_metrics(metrics)
  File "D:\daftar\recomendation systems\sequence-based-recommendations-master\neural_networks\rnn_base.py", line 365, in _compute_validation_metrics
    for batch_input, goal in self._gen_mini_batch(self.dataset.validation_set(epochs=1), test=True):
RuntimeError: generator raised StopIteration

What might be causing this error?

n0obcoder commented 4 years ago

I noticed that this error occurs once all 500 test examples have been yielded: calling next() on the exhausted generator raises StopIteration, which then surfaces as the RuntimeError above. But I still don't know how to resolve it.
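The behavior described here is PEP 479, which became the default in Python 3.7: a StopIteration raised inside a generator body is no longer treated as the generator finishing, but is re-raised as `RuntimeError: generator raised StopIteration`. The toy example below is a minimal sketch that reproduces the pattern outside the repository; `batches` and `validation_set` are hypothetical stand-ins for `_gen_mini_batch` and `self.dataset.validation_set(epochs=1)`, not the project's code.

```python
def batches(source):
    """Hypothetical stand-in for _gen_mini_batch: pulls items with next() in an endless loop."""
    while True:
        item = next(source)  # raises StopIteration once `source` is exhausted
        yield [item]


# Hypothetical stand-in for self.dataset.validation_set(epochs=1): a finite iterator.
validation_set = iter(range(3))

try:
    for batch in batches(validation_set):
        print(batch)
except RuntimeError as err:
    # On Python 3.7+ the StopIteration from next() is converted into this RuntimeError.
    print("RuntimeError:", err)
```

The three batches are printed first; only the fourth call to `next()` inside the generator triggers the RuntimeError, which matches the failure appearing after the last validation example.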

n0obcoder commented 4 years ago

I fixed it with the help of https://stackoverflow.com/questions/51700960/runtimeerror-generator-raised-stopiteration-every-time-i-try-to-run-app

I changed

while True:
    j = 0
    sequences = []
    batch_size = self.batch_size
    if test:
        batch_size = 1
    while j < batch_size:

        # raises StopIteration once sequence_generator is exhausted
        sequence, user_id = next(sequence_generator)

        # finds the lengths of the different subsequences
        if not test:
            seq_lengths = sorted(random.sample(range(2, len(sequence)), min([batch_size - j, len(sequence) - 2, max_reuse_sequence])))
        else:
            seq_lengths = [len(sequence) // 2]

        skipped_seq = 0
        for l in seq_lengths:
            target = self.target_selection(sequence[l:], test=test)
            if len(target) == 0:
                skipped_seq += 1
                continue
            start = max(0, l - self.max_length)  # sequences cannot be longer than self.max_length
            sequences.append([user_id, sequence[start:l], target])

        j += len(seq_lengths) - skipped_seq

    if test:
        # sequence[seq_lengths[0]:] is the ground truth (ratings included)
        # [i[0] for i in sequence[seq_lengths[0]:]] is the ground truth (ratings excluded)
        yield self._prepare_input(sequences), [i[0] for i in sequence[seq_lengths[0]:]]
    else:
        yield self._prepare_input(sequences)

to

while True:
    try:
        j = 0
        sequences = []
        batch_size = self.batch_size
        if test:
            batch_size = 1
        while j < batch_size:

            # StopIteration raised here is now caught by the except clause below
            sequence, user_id = next(sequence_generator)

            # finds the lengths of the different subsequences
            if not test:
                seq_lengths = sorted(random.sample(range(2, len(sequence)), min([batch_size - j, len(sequence) - 2, max_reuse_sequence])))
            else:
                seq_lengths = [len(sequence) // 2]

            skipped_seq = 0
            for l in seq_lengths:
                target = self.target_selection(sequence[l:], test=test)
                if len(target) == 0:
                    skipped_seq += 1
                    continue
                start = max(0, l - self.max_length)  # sequences cannot be longer than self.max_length
                sequences.append([user_id, sequence[start:l], target])

            j += len(seq_lengths) - skipped_seq

        if test:
            # sequence[seq_lengths[0]:] is the ground truth (ratings included)
            # [i[0] for i in sequence[seq_lengths[0]:]] is the ground truth (ratings excluded)
            yield self._prepare_input(sequences), [i[0] for i in sequence[seq_lengths[0]:]]
        else:
            yield self._prepare_input(sequences)
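    except StopIteration:
        # the source generator is exhausted: end this generator normally instead of
        # letting StopIteration escape (which PEP 479 turns into RuntimeError)
        return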
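This fix works because the StopIteration is caught inside the generator and turned into a plain `return`, which ends the generator normally. A narrower variant wraps only the `next()` call, so an unrelated StopIteration elsewhere in the loop body would not be silently swallowed, and it can also flush a final partial batch instead of dropping it. The sketch below is a simplified, hypothetical stand-in for `_gen_mini_batch`, not the repository's code: it omits the subsequence sampling, `target_selection`, and `_prepare_input`.

```python
from typing import Iterator, List, Tuple


def gen_mini_batch(
    sequence_generator: Iterator[Tuple[list, int]], batch_size: int
) -> Iterator[List[Tuple[int, list]]]:
    """Hypothetical simplified batcher that handles exhaustion in a PEP 479-safe way."""
    while True:
        batch: List[Tuple[int, list]] = []
        while len(batch) < batch_size:
            try:
                sequence, user_id = next(sequence_generator)
            except StopIteration:
                # Source exhausted: emit any partial batch, then finish normally.
                if batch:
                    yield batch
                return
            batch.append((user_id, sequence))
        yield batch


data = iter([([1, 2, 3], 0), ([4, 5], 1), ([6], 2)])
print(list(gen_mini_batch(data, batch_size=2)))
# → [[(0, [1, 2, 3]), (1, [4, 5])], [(2, [6])]]
```

Whether a trailing partial batch should be kept is a design choice. With `test=True` the batch size is 1, so the original fix loses nothing either way. Another common option is `next(sequence_generator, None)` with a sentinel check in place of the `try`/`except`.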