tensorflow / tensor2tensor

Library of deep learning models and datasets designed to make deep learning more accessible and accelerate ML research.
Apache License 2.0

Translate English to Chinese with own data, but failed. #1173

Closed lowblung closed 5 years ago

lowblung commented 6 years ago

Description

I want to use my own set of data, but it is not working. Could anyone help with this issue?

Environment information

OS: CentOS

$ pip freeze | grep tensor
tensor2tensor==1.6.2
tensorboard==1.8.0
tensorflow-gpu==1.8.0
tensorflow-serving-api-python3==1.7.0
tensorflow-tensorboard==0.4.0

$ python -V
Python 3.6

My problem definition (t2t_usr/translate_ron.py):

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators import text_problems
from tensor2tensor.data_generators import translate
from tensor2tensor.utils import registry

import tensorflow as tf

FLAGS = tf.flags.FLAGS

# End-of-sentence marker.
EOS = text_encoder.EOS_ID

_IPO_TRAIN_DATASETS = ["train.final.en","train.final.zh"]

_IPO_TEST_DATASETS = ["dev.final.en","dev.final.zh"]
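
# NOTE: get_filename and the source/target comprehensions below assume each
# dataset entry is nested like [[url, [source_file, target_file]]]; these
# flat string lists do not match that shape.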

def get_filename(dataset):
  return dataset[0][0].split("/")[-1]

@registry.register_problem
class TranslateRon32k(translate.TranslateProblem):
  """Problem spec for WMT En-Zh translation.
  Attempts to use full training dataset, which needs website
  registration and downloaded manually from official sources:
  CWMT:
    - http://nlp.nju.edu.cn/cwmt-wmt/
    - Website contains instructions for FTP server access.
    - You'll need to download CASIA, CASICT, DATUM2015, DATUM2017,
        NEU datasets
  UN Parallel Corpus:
    - https://conferences.unite.un.org/UNCorpus
    - You'll need to register your to download the dataset.
  NOTE: place into tmp directory e.g. /tmp/t2t_datagen/dataset.tgz
  """

  @property
  def approx_vocab_size(self):
    return 2**15  # 32k

  @property
  def source_vocab_name(self):
    return "%s.en" % self.vocab_filename

  @property
  def target_vocab_name(self):
    return "%s.zh" % self.vocab_filename

  def get_training_dataset(self, tmp_dir):
    """UN Parallel Corpus and CWMT Corpus need to be downloaded manually.
    Appends to the training dataset if available.
    Args:
      tmp_dir: path to temporary dir with the data in it.
    Returns:
      paths
    """
    full_dataset = _NC_TRAIN_DATASETS
    for dataset in [_IPO_TRAIN_DATASETS, _IPO_TRAIN_DATASETS]:
      filename = get_filename(dataset)
      tmp_filepath = os.path.join(tmp_dir, filename)
      if tf.gfile.Exists(tmp_filepath):
        full_dataset += dataset
      else:
        tf.logging.info("[TranslateEzhWmt] dataset incomplete, you need to "
                        "manually download %s" % filename)
    return full_dataset

  def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
    train = dataset_split == problem.DatasetSplit.TRAIN
    train_dataset = self.get_training_dataset(tmp_dir)
    datasets = train_dataset if train else _NC_TEST_DATASETS
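    # Each item is expected to be [url, [src_file, tgt_file]]; with the flat
    # string lists above, item[1] is a single character, so item[1][1] raises
    # the IndexError shown in the logs below.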
    source_datasets = [[item[0], [item[1][0]]] for item in train_dataset]
    target_datasets = [[item[0], [item[1][1]]] for item in train_dataset]
    source_vocab = generator_utils.get_or_generate_vocab(
        data_dir,
        tmp_dir,
        self.source_vocab_name,
        self.approx_vocab_size,
        source_datasets,
        file_byte_budget=1e8)
    target_vocab = generator_utils.get_or_generate_vocab(
        data_dir,
        tmp_dir,
        self.target_vocab_name,
        self.approx_vocab_size,
        target_datasets,
        file_byte_budget=1e8)
    tag = "train" if train else "dev"
    filename_base = "wmt_enzh_%sk_tok_%s" % (self.approx_vocab_size, tag)
    data_path = translate.compile_data(tmp_dir, datasets, filename_base)
    return text_problems.text2text_generate_encoded(
        text_problems.text2text_txt_iterator(data_path + ".lang1",
                                             data_path + ".lang2"),
        source_vocab, target_vocab)

  def feature_encoders(self, data_dir):
    source_vocab_filename = os.path.join(data_dir, self.source_vocab_name)
    target_vocab_filename = os.path.join(data_dir, self.target_vocab_name)
    source_token = text_encoder.SubwordTextEncoder(source_vocab_filename)
    target_token = text_encoder.SubwordTextEncoder(target_vocab_filename)
    return {
        "inputs": source_token,
        "targets": target_token,
    }

@registry.register_problem
class TranslateRon8k(TranslateRon32k):
  """Problem spec for WMT En-Zh translation.
  This is far from being the real WMT17 task - only a toy dataset here.
  """

  @property
  def approx_vocab_size(self):
    return 2**13  # 8192

  @property
  def dataset_splits(self):
    return [
        {
            "split": problem.DatasetSplit.TRAIN,
            "shards": 10,  # this is a small dataset
        },
        {
            "split": problem.DatasetSplit.EVAL,
            "shards": 1,
        }
    ]

  def get_training_dataset(self, tmp_dir):
    """Uses only News Commentary Dataset for training."""
    return _IPO_TRAIN_DATASETS

### For bugs: reproduction and error logs

Steps to reproduce:

USR_DIR=$HOME/t2t_usr
PROBLEM=translate_ron8k
DATA_DIR=$HOME/t2t_data
TMP_DIR=$HOME/tmp/t2t_datagen

mkdir -p $DATA_DIR $TMP_DIR $USR_DIR

t2t-datagen \
  --t2t_usr_dir=$HOME/t2t_usr \
  --data_dir=$HOME/t2t_data \
  --tmp_dir=$HOME/tmp/t2t_datagen \
  --problem=$PROBLEM


# Error logs:

/usr/local/tensorflow/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters

INFO:tensorflow:Importing user module t2t_usr from path /data/klching
[2018-10-26 17:14:59,362] Importing user module t2t_usr from path /data/klching
INFO:tensorflow:Generating problems:
    translate:
      * translate_ron8k
[2018-10-26 17:14:59,364] Generating problems:
    translate:
      * translate_ron8k
INFO:tensorflow:Generating data for translate_ron8k.
[2018-10-26 17:14:59,365] Generating data for translate_ron8k.
Traceback (most recent call last):
  File "/usr/local/tensorflow/bin/t2t-datagen", line 27, in <module>
    tf.app.run()
  File "/usr/local/tensorflow/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 126, in run
    _sys.exit(main(argv))
  File "/usr/local/tensorflow/bin/t2t-datagen", line 23, in main
    t2t_datagen.main(argv)
  File "/usr/local/tensorflow/lib/python3.6/site-packages/tensor2tensor/bin/t2t_datagen.py", line 184, in main
    generate_data_for_registered_problem(problem)
  File "/usr/local/tensorflow/lib/python3.6/site-packages/tensor2tensor/bin/t2t_datagen.py", line 233, in generate_data_for_registered_problem
    problem.generate_data(data_dir, tmp_dir, task_id)
  File "/usr/local/tensorflow/lib/python3.6/site-packages/tensor2tensor/data_generators/text_problems.py", line 283, in generate_data
    self.generate_encoded_samples(data_dir, tmp_dir, split)), paths)
  File "/data/klching/t2t_usr/translate_ron.py", line 101, in generate_encoded_samples
    target_datasets = [[item[0], [item[1][1]]] for item in train_dataset]
  File "/data/klching/t2t_usr/translate_ron.py", line 101, in <listcomp>
    target_datasets = [[item[0], [item[1][1]]] for item in train_dataset]
IndexError: string index out of range
afrozenator commented 6 years ago

I don't see _NC_TRAIN_DATASETS defined anywhere. Also, the problem seems to be in the way you are trying to create the data, i.e. in this line:

   target_datasets = [[item[0], [item[1][1]]] for item in train_dataset]

it is either failing on item[0] (unlikely), or item[1] (also unlikely), or item[1][1] (most likely) -- you can try running the code itself in the Python interpreter and check?
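
A quick hypothetical interpreter session with the flat lists from translate_ron.py shows why item[1][1] is the failing index:

>>> train_dataset = ["train.final.en", "train.final.zh"]
>>> item = train_dataset[0]  # a plain string, not [url, [files]]
>>> item[1]                  # second character of the filename
'r'
>>> item[1][1]               # indexing into a 1-character string
Traceback (most recent call last):
  ...
IndexError: string index out of range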

In the worst case, you can use the Text2TextTmpDir problem, prepare line-by-line data, and use that to train.
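
As a sketch of that fallback (assuming the stock text2text_tmpdir problem, i.e. Text2textTmpdir from tensor2tensor's text_problems.py), you would put one record per line into inputs.train.txt, targets.train.txt, inputs.eval.txt and targets.eval.txt inside the tmp dir, then run datagen without any user problem:

t2t-datagen \
  --data_dir=$DATA_DIR \
  --tmp_dir=$TMP_DIR \
  --problem=text2text_tmpdir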

lowblung commented 5 years ago

I have added the link as item[0] to make it run. It is now kind of "working" with tensor2tensor.

_IPO_TEST_DATASETS = [[ "https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz", ["dev.final.en","dev.final.zh"] ]]
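
For completeness, a hypothetical train-side analog of that fix, matching the [url, [source_file, target_file]] shape that get_filename and the source/target comprehensions index into:

_IPO_TRAIN_DATASETS = [[
    "https://s3-us-west-2.amazonaws.com/twairball.wmt17.zh-en/cwmt.tgz",
    ["train.final.en", "train.final.zh"]
]]

With entries nested this way, item[0] is the archive URL and item[1][0] / item[1][1] are the source and target filenames.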

afrozenator commented 5 years ago

If I understand correctly, it is working now? Let me know if it isn't. Closing this now; feel free to re-open if there are issues, ok?