rpryzant / delete_retrieve_generate

PyTorch implementation of the Delete, Retrieve, Generate style transfer algorithm
MIT License
132 stars, 26 forks

Only works for parallel sentences #12

Closed: TanmayParekh closed this issue 4 years ago

TanmayParekh commented 4 years ago

The code apparently only works if we have parallel sentences: it throws an error if the number of sentences in the source and target files doesn't match. This was definitely not a requirement of the original Delete, Retrieve, Generate method.

rpryzant commented 4 years ago

Hmm, can you give me the data/config you used and the stack trace?

I just confirmed it works for the example yelp data which is not parallel.

TanmayParekh commented 4 years ago

We trained the system on Yelp to transfer from positive to negative sentiment. More specifically, we used the following config/data (we regenerated the vocab files for training in the opposite direction):

"data": {
    "src": "data/yelp/sentiment.train.1",
    "tgt": "data/yelp/sentiment.train.0",
    "src_test": "data/yelp/reference.test.1",
    "tgt_test": "data/yelp/reference.test.0",
    "src_vocab": "data/yelp/vocab",
    "tgt_vocab": "data/yelp/vocab",
    "share_vocab": true,
    "attribute_vocab": "data/yelp/ngram.15.attribute",
    "ngram_attributes": true,
    "batch_size": 256,
    "max_len": 50,
    "working_dir": "working_dir"
}
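
Because `src` and `tgt` in this config are separate monolingual files, a quick sanity check is to compare their sentence counts before training. The sketch below assumes one sentence per line; `count_lines` and `check_sizes` are hypothetical helpers for illustration, not part of the repository.

```python
from pathlib import Path


def count_lines(path: str) -> int:
    """Count the sentences in a corpus file, assuming one sentence per line."""
    with Path(path).open(encoding="utf-8") as f:
        return sum(1 for _ in f)


def check_sizes(data_config: dict) -> tuple:
    """Hypothetical helper: compare the sizes of the "src" and "tgt" corpora.

    Returns (n_src, n_tgt, same_size). Equal counts are necessary but not
    sufficient for parallel data; unequal counts mean the data cannot be
    parallel line by line.
    """
    n_src = count_lines(data_config["src"])
    n_tgt = count_lines(data_config["tgt"])
    return n_src, n_tgt, n_src == n_tgt
```

Calling `check_sizes(config["data"])` on the config above would report both line counts, with `same_size` set to `False` whenever the positive and negative training files differ in length.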

Since the code takes a long time to reach the error, here is a sample stack trace from a run on non-parallel data (specifically, the case where there are more source sentences than target sentences):

Traceback (most recent call last):
  File "train.py", line 160, in <module>
    src, tgt, i, batch_size, max_length, config['model']['model_type'])
  File "/home2/tparekh/politeness_project/delete_retrieve_generate/src/data.py", line 291, in minibatch
    in_dataset['content'], in_dataset['tok2id'], idx, batch_size, max_len, sort=True)
  File "/home2/tparekh/politeness_project/delete_retrieve_generate/src/data.py", line 236, in get_minibatch
    max_len = max(lens)
ValueError: max() arg is an empty sequence
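
The traceback is consistent with a training loop that uses the same batch offset `i` for both corpora: once `i` runs past the end of the shorter target file, the target slice is empty and `max(lens)` has nothing to compare. Below is a minimal, hypothetical reproduction of that failure mode. The function names and loop structure are assumptions and are not the repository's actual `train.py`/`data.py` code. `wrapped_batches` shows one possible workaround (cycling the target offset), which is not necessarily how the upstream fix works.

```python
def get_minibatch(lines, idx, batch_size):
    """Simplified stand-in for get_minibatch: slice one batch of tokenized
    sentences and return it with the length of its longest sentence."""
    batch = lines[idx:idx + batch_size]
    lens = [len(sentence) for sentence in batch]
    # If idx is past the end of `lines`, `batch` is empty and max() raises
    # "ValueError: max() arg is an empty sequence", as in the traceback.
    return batch, max(lens)


def naive_batches(src, tgt, batch_size):
    """Assumed failing pattern: one offset drives both corpora."""
    for idx in range(0, len(src), batch_size):
        yield get_minibatch(src, idx, batch_size), get_minibatch(tgt, idx, batch_size)


def wrapped_batches(src, tgt, batch_size):
    """One possible workaround: wrap the target offset so a shorter target
    corpus is cycled through instead of running out."""
    if not tgt:
        raise ValueError("target corpus is empty")
    for idx in range(0, len(src), batch_size):
        tgt_idx = idx % len(tgt)  # always a valid start index into tgt
        yield get_minibatch(src, idx, batch_size), get_minibatch(tgt, tgt_idx, batch_size)
```

With 5 source sentences, 2 target sentences, and `batch_size=2`, `naive_batches` fails on its second batch, while `wrapped_batches` yields three batches. One side effect of wrapping is that a target batch can be smaller than `batch_size` when `len(tgt)` is not a multiple of it.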

rpryzant commented 4 years ago

Thanks for the info.

Fixed in 136abd7ddb3be8689e95ee81e5eb138800619dc0