facebookresearch / ParlAI

A framework for training and evaluating AI models on a variety of openly available dialogue datasets.
https://parl.ai
MIT License

Validation & Testing For Dataset Extending ParlAIDialogTeacher #4278

Closed: michaelyma12 closed this issue 2 years ago

michaelyma12 commented 2 years ago

Hello all,

Thanks for this amazing library. I'm currently trying to fine-tune zoo:tutorial_transformer_generator/model from this tutorial on my own dataset.

Here's a sample of my data:

text:How was work??       labels:Not bad, how was your day?
text:Very good!!  labels:That’s good!
text:Doing bad stuff     labels:Oh yeah????
text:Oh yeah!    labels:Like what
text:you don't wanna know   labels:oh i bet i do!        episode_done:True

I'm running the following code:

import copy
import os

from parlai.core.loader import register_teacher
from parlai.core.teachers import ParlAIDialogTeacher
from parlai.scripts.train_model import TrainModel


@register_teacher('default_teacher')
class DefaultTeacher(ParlAIDialogTeacher):

    def __init__(self, opt, shared=None):
        """
        Point the teacher at the datafile matching the current datatype.

        :param opt: ParlAI options dict
        :param shared: shared state when the teacher is copied across workers
        """
        opt = copy.deepcopy(opt)
        # 'train:stream' / 'valid' / 'test' -> train.txt / valid.txt / test.txt
        opt['parlaidialogteacher_datafile'] = os.path.join(
            'artifacts', 'mvp', 'datasets', '{}.txt'.format(opt['datatype'].split(':')[0])
        )
        super().__init__(opt, shared)

TrainModel.main(
    # similar to before
    task='default_teacher',
    model='transformer/generator',
    model_file='from_pretrained/model',

    # initialize with a pretrained model
    init_model='zoo:tutorial_transformer_generator/model',

    # arguments we get from the pretrained model.
    # Unfortunately, these must be looked up separately for each model.
    n_heads=16, n_layers=8, n_positions=512, text_truncate=512,
    label_truncate=128, ffn_size=2048, embedding_size=512,
    activation='gelu', variant='xlm',
    dict_lower=True, dict_tokenizer='bpe',
    dict_file='zoo:tutorial_transformer_generator/model.dict',
    learn_positional_embeddings=True,

    # some training arguments, specific to this fine-tuning
    # use a small learning rate with ADAM optimizer
    lr=1e-5, optimizer='adam',
    warmup_updates=100,
    # early stopping on perplexity
    validation_metric='ppl',
    # train at most 10 minutes, and validate every 0.25 epochs
    max_train_time=600, validation_every_n_epochs=0.25,
    # depends on your GPU; if you have a V100, this is good
    batchsize=12,
    fp16=True,
    fp16_impl='mem_efficient',
    # speeds up validation
    skip_generation=True,
    # helps us cram more examples into our gpu at a time
    dynamic_batching='full'
)
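
As an aside, a quick way to check that the teacher resolves the right file for a given datatype is to print a few examples with ParlAI's display_data script. This is just a sketch, assuming the 'default_teacher' registration above:

from parlai.scripts.display_data import DisplayData

# print a handful of validation examples from the custom teacher,
# which should come from artifacts/mvp/datasets/valid.txt
DisplayData.main(
    task='default_teacher',
    datatype='valid',
    num_examples=5,
)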

When I run the training call, I receive the following error:

Traceback (most recent call last):
  File "/Users/mma/opt/anaconda3/envs/charmander/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3457, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-34-6b1186e0bd29>", line 34, in <module>
    dynamic_batching='full'
  File "/Users/mma/tmg/ParlAI/parlai/core/script.py", line 127, in main
    return cls._run_kwargs(kwargs)
  File "/Users/mma/tmg/ParlAI/parlai/core/script.py", line 92, in _run_kwargs
    return cls._run_from_parser_and_opt(opt, parser)
  File "/Users/mma/tmg/ParlAI/parlai/core/script.py", line 108, in _run_from_parser_and_opt
    return script.run()
  File "/Users/mma/tmg/ParlAI/parlai/scripts/train_model.py", line 998, in run
    return self.train_loop.train()
  File "/Users/mma/tmg/ParlAI/parlai/scripts/train_model.py", line 958, in train
    valid_worlds, opt, 'valid', max_exs, write_log=True
  File "/Users/mma/tmg/ParlAI/parlai/scripts/train_model.py", line 651, in _run_eval
    task_report = self._run_single_eval(opt, v_world, max_exs_per_worker)
  File "/Users/mma/tmg/ParlAI/parlai/scripts/train_model.py", line 609, in _run_single_eval
    valid_world.parley()
  File "/Users/mma/tmg/ParlAI/parlai/core/worlds.py", line 1136, in parley
    obs = self.worlds[i].get_model_agent().observe(act)
  File "/Users/mma/tmg/ParlAI/parlai/core/torch_agent.py", line 1833, in observe
    self._validate_observe_invariants()
  File "/Users/mma/tmg/ParlAI/parlai/core/torch_agent.py", line 1938, in _validate_observe_invariants
    "Last observe() had a label, but no call to self_observe ever "
RuntimeError: Last observe() had a label, but no call to self_observe ever happened. You are likely making multiple observe() calls without a corresponding act(). This was changed in #2043. File a GitHub issue if you require assistance.

The training function works if I simply set: opt['parlaidialogteacher_datafile'] = os.path.join('artifacts', 'mvp', 'datasets', 'train.txt')

But that would run the validation and testing steps on the train set as well; obviously I would like a separate dataset for each split.
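
To spell out the layout I'm expecting the teacher above to resolve (via opt['datatype'].split(':')[0]):

artifacts/mvp/datasets/train.txt   # datatype 'train' / 'train:stream'
artifacts/mvp/datasets/valid.txt   # datatype 'valid'
artifacts/mvp/datasets/test.txt    # datatype 'test'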

Any insight would be appreciated - many thanks again.

michaelyma12 commented 2 years ago

Solved - for anyone viewing this in the future, I had to add the flag eval_dynamic_batching='off' to TrainModel.main().
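
In other words, the training call above now looks roughly like this (unchanged arguments elided):

TrainModel.main(
    task='default_teacher',
    model='transformer/generator',
    model_file='from_pretrained/model',
    init_model='zoo:tutorial_transformer_generator/model',
    # ... same model and training arguments as in the original call ...
    dynamic_batching='full',
    # keep dynamic batching for training, but turn it off for the
    # validation/test worlds to avoid the observe()/self_observe error
    eval_dynamic_batching='off',
)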