allenai / unifiedqa

UnifiedQA: Crossing Format Boundaries With a Single QA System
https://arxiv.org/abs/2005.00700
Apache License 2.0

ValueError: exactly one of keep_prob and rate should be set #34

Closed ayush714 closed 2 years ago

ayush714 commented 3 years ago

I am running the code below:

    import os
    from time import time

    import tensorflow as tf
    import t5

    MODEL_SIZE = "large"
    BASE_PRETRAINED_DIR = "gs://unifiedqa/models/large"
    PRETRAINED_DIR = BASE_PRETRAINED_DIR
    # MODEL_DIR is assumed to be defined earlier as the base output directory.
    MODEL_DIR = os.path.join(MODEL_DIR, MODEL_SIZE)

    # Per-size settings: (model_parallelism, train_batch_size, keep_checkpoint_max)
    model_parallelism, train_batch_size, keep_checkpoint_max = {
        "small": (1, 256, 16),
        "base": (2, 128, 8),
        "large": (8, 64, 4),
        "3B": (8, 16, 1),
        "11B": (8, 16, 1)}[MODEL_SIZE]
    tf.io.gfile.makedirs(MODEL_DIR)
    ON_CLOUD = False
    model = t5.models.MtfModel(
        model_dir=MODEL_DIR,
        tpu=None,
        model_parallelism=model_parallelism,
        batch_size=train_batch_size,
        sequence_length={"inputs": 128, "targets": 32},
        learning_rate_schedule=0.003,
        save_checkpoints_steps=5000,
        keep_checkpoint_max=keep_checkpoint_max if ON_CLOUD else None,
        iterations_per_loop=100,
    )
    FINETUNE_STEPS = 9

    logInfo("Started Training the model")
    start = time()
    model.finetune(
        mixture_or_task_name="qa_t5_meshs",
        pretrained_model_dir=PRETRAINED_DIR,
        finetune_steps=FINETUNE_STEPS
    )
    logInfo("Completed model training.", time_taken=time() - start)

and getting the following error:

ValueError: exactly one of keep_prob and rate should be set
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
~/ml-experimentation-dev/research/qna_finetuning/code/run_t5mesh_train.py in <module>
     127 logInfo("Started Training the model")
     128 start = time()
---> 129 model.finetune(
     130     mixture_or_task_name="qa_t5_meshs",
     131     pretrained_model_dir=PRETRAINED_DIR,
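
For reference, this message comes from an argument check on the dropout settings inside the t5 / mesh-tensorflow stack: the check accepts exactly one of keep_prob and rate, and it typically fires when the installed t5 and mesh-tensorflow versions disagree about which of the two argument names to pass. Below is a minimal sketch of that kind of check; it is illustrative only, not the actual library source, and the function name and signature are assumptions:

    def dropout(x, is_training, keep_prob=None, rate=None):
        # Exactly one of keep_prob / rate must be given; passing both or
        # neither triggers the ValueError shown in the traceback above.
        if (keep_prob is None) == (rate is None):
            raise ValueError("exactly one of keep_prob and rate should be set")
        if rate is None:
            rate = 1.0 - keep_prob
        # ... apply dropout with probability `rate` when is_training is True ...
        return x

If that is the cause, aligning the installed t5 and mesh-tensorflow versions so that only one of the two arguments is passed would be a plausible fix, though the exact compatible versions would need to be checked.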
danyaljj commented 2 years ago

This is a good question to ask in the T5 repository.