allenai / allennlp

An open-source NLP research library, built on PyTorch.
http://www.allennlp.org
Apache License 2.0

When retraining the BERT-based coreference resolution model, the coreference recall is very low. How can I solve it? #4853

Closed yqw-vicki closed 3 years ago

yqw-vicki commented 3 years ago

Hi, I am using the coreference resolver to retrain a BERT-based coreference model on OntoNotes 5.0, but the results are not satisfying. Even though the mention recall is very high, the coreference recall is still low. The results are below:

[screenshot of evaluation results]

My training code is shown here:

# Imports below are a best guess for allennlp 1.x with allennlp-models installed;
# adjust the paths to your installed version. `device` (a GPU index or
# torch.device) is assumed to be defined elsewhere in the script.
import os

import torch
from torch.utils.data import DataLoader

from allennlp.data import Vocabulary
from allennlp.data.dataloader import allennlp_collate
from allennlp.data.token_indexers import PretrainedTransformerMismatchedIndexer
from allennlp.models import Model
from allennlp.modules import FeedForward
from allennlp.modules.seq2seq_encoders import LstmSeq2SeqEncoder
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import PretrainedTransformerMismatchedEmbedder
from allennlp.training import Checkpointer, GradientDescentTrainer, Trainer
from allennlp.training.learning_rate_schedulers import MultiStepLearningRateScheduler
from allennlp.training.optimizers import AdamOptimizer
from allennlp_models.coref import ConllCorefReader, CoreferenceResolver


def build_trainer(
    model: Model,
    serialization_dir: str,
    train_loader: DataLoader,
    dev_loader: DataLoader
) -> Trainer:

    # Collect the parameter names of the transformer embedder; they are used
    # below as regexes for a separate optimizer parameter group.
    bert_embedder = []
    for n, p in model.named_parameters():
        if n.startswith('_text_field_embedder'):
            bert_embedder.append(n)

    parameters = [
        (n, p)
        for n, p in model.named_parameters() if p.requires_grad
    ]

    # Give the transformer embedder parameters their own (higher) learning rate.
    group_parameter = [
        (bert_embedder, {"lr": 2e-4})
    ]

    optimizer = AdamOptimizer(parameters, group_parameter, lr=1e-5)
    checkpoint = Checkpointer(serialization_dir=serialization_dir)
    # Multi-step LR decay, with a milestone at every index up to the number of training instances.
    lr_scheduler = MultiStepLearningRateScheduler(optimizer, list(range(0, len(train_loader.dataset.instances))))
    trainer = GradientDescentTrainer(
        model=model,
        serialization_dir=serialization_dir,
        checkpointer=checkpoint,
        data_loader=train_loader,
        validation_data_loader=dev_loader,
        patience=10,
        num_epochs=20,
        cuda_device=device,
        learning_rate_scheduler=lr_scheduler,
        optimizer=optimizer
    )
    return trainer

# Data paths: OntoNotes 5.0 in CoNLL format.
dirpath = os.path.abspath(os.path.dirname(os.getcwd()))
train_filepath = os.path.join(dirpath, "conll-formatted-ontonotes-5.0/data/train")
valid_filepath = os.path.join(dirpath, "conll-formatted-ontonotes-5.0/data/development")

# Dataset reader: max span width of 30, BERT wordpiece indexer, documents truncated to 10 sentences.
transformer_model = 'bert-base-cased'
token_indexer = PretrainedTransformerMismatchedIndexer(model_name=transformer_model, max_length=512)
reader = ConllCorefReader(30, {'bert_tokens': token_indexer}, max_sentences=10)
train_dataset = reader.read(train_filepath)
validation_dataset = reader.read(valid_filepath)
vocab = Vocabulary()
train_dataset.index_with(vocab)
validation_dataset.index_with(vocab)
train_loader = DataLoader(train_dataset, batch_size=1, collate_fn=allennlp_collate)
dev_loader = DataLoader(validation_dataset, batch_size=1, collate_fn=allennlp_collate)
# Model components: mismatched BERT embedder + bidirectional LSTM context layer + feedforward scorers.
embedding = PretrainedTransformerMismatchedEmbedder(transformer_model, max_length=512,
                                                    last_layer_only=True, gradient_checkpointing=True)
embedder = BasicTextFieldEmbedder(token_embedders={'bert_tokens': embedding})
encoder = LstmSeq2SeqEncoder(embedder.get_output_dim(), 200, 1, dropout=0.2, bidirectional=True)
# Span embedding size: endpoint extractor (2 * encoder dim + 20-dim width feature) + attentive extractor (embedder dim).
span_dimension = 2*encoder.get_output_dim() + embedder.get_output_dim() + 20
mention_feedforward = FeedForward(span_dimension, 2, [150, 150], torch.nn.ReLU())
# Antecedent pair embedding size: 3 * span embedding + 20-dim distance feature.
antecedent_dimension = 3*span_dimension + 20
antecedent_feedforward = FeedForward(antecedent_dimension, 2, [150, 150], torch.nn.ReLU())
corefer = CoreferenceResolver(vocab, text_field_embedder=embedder, context_layer=encoder,
                            mention_feedforward=mention_feedforward, antecedent_feedforward=antecedent_feedforward,
                            feature_size=20, max_span_width=30, spans_per_word=0.4, max_antecedents=50).to(device)

Is there any problem in my training code?

dirkgr commented 3 years ago

It sounds to me like you're trying to replicate the experiment in this config file: https://github.com/allenai/allennlp-models/blob/master/training_config/coref/coref_bert_lstm.jsonnet. Can you adapt your code to use the parameters from this config, and try again? I didn't look to find all the differences, but your learning rate scheduler is different, and so is the number of epochs.
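
In case it helps, here is a minimal sketch of training directly from that config instead of rebuilding the pipeline by hand. The config path and serialization directory below are placeholders, and it assumes allennlp-models is installed so the coref reader and model are registered:

import allennlp_models  # noqa: F401  (registers the coref dataset reader and model)
from allennlp.commands.train import train_model_from_file

train_model_from_file(
    "training_config/coref/coref_bert_lstm.jsonnet",  # local copy of the config linked above
    "coref_bert_lstm_out",                            # placeholder serialization directory
)

The same thing can be done from the command line with `allennlp train <config> -s <output_dir> --include-package allennlp_models`.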

yqw-vicki commented 3 years ago

Thank you very much! I followed the config and modified my model; it works now.