lavis-nlp/spert

PyTorch code for SpERT: Span-based Entity and Relation Transformer
MIT License

Dataloader gets stuck when using multiprocessing #32

zhollow closed this issue 3 years ago

zhollow commented 3 years ago

The config file I use:

[1]
label = conll04_train
model_type = spert
#model_path = bert-base-cased
#tokenizer_path = bert-base-cased
model_path = data/models/conll04
tokenizer_path = data/models/conll04
train_path = data/datasets/conll04/conll04_train.json
valid_path = data/datasets/conll04/conll04_dev.json
types_path = data/datasets/conll04/conll04_types.json
train_batch_size = 2
eval_batch_size = 1
neg_entity_count = 100
neg_relation_count = 100
epochs = 20
lr = 5e-5
lr_warmup = 0.1
weight_decay = 0.01
max_grad_norm = 1.0
rel_filter_threshold = 0.4
size_embedding = 25
prop_drop = 0.1
max_span_size = 10
store_predictions = true
store_examples = true
sampling_processes = 4
sampling_limit = 100
max_pairs = 1000
final_eval = true
log_path = data/log/
save_path = data/save/

I only changed model_path and tokenizer_path. When I ran python spert.py train --config configs/example_train.conf, it got stuck at the dataloader (see the comments in the code):

# in spert_trainer.py _train_epoch
        for batch in tqdm(data_loader, total=total, desc='Train epoch %s' % epoch):  ##### stuck at this line
            model.train()   # never reached here
            batch = util.to_device(batch, self._device)

            # forward step
            entity_logits, rel_logits = model(encodings=batch['encodings'], context_masks=batch['context_masks'],
                                              entity_masks=batch['entity_masks'], entity_sizes=batch['entity_sizes'],
                                              relations=batch['rels'], rel_masks=batch['rel_masks'])

            # compute loss and optimize parameters
            batch_loss = compute_loss.compute(entity_logits=entity_logits, rel_logits=rel_logits,
                                              rel_types=batch['rel_types'], entity_types=batch['entity_types'],
                                              entity_sample_masks=batch['entity_sample_masks'],
                                              rel_sample_masks=batch['rel_sample_masks'])

But when I changed sampling_processes to 0, it worked, though slowly.

Why does execution get stuck with sampling_processes=4?

I am training the model on a CPU, if that matters.
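
To narrow it down, I also tried a minimal loop outside of SpERT. This sketch assumes sampling_processes ends up as the num_workers argument of torch.utils.data.DataLoader (my reading of the trainer code); DummyDataset is just a stand-in for the real sampler. If this also hangs with num_workers=4, the problem is in the multiprocessing setup rather than in SpERT itself:

    # Minimal repro sketch, outside of SpERT. Assumption: sampling_processes
    # becomes the num_workers argument of torch.utils.data.DataLoader.
    # DummyDataset is a stand-in, not SpERT's real sampler.
    import torch
    from torch.utils.data import DataLoader, Dataset

    class DummyDataset(Dataset):
        def __len__(self):
            return 64

        def __getitem__(self, idx):
            return torch.randn(8)  # stand-in for per-document sampling work

    if __name__ == '__main__':  # guard is required when workers are spawned
        for workers in (0, 4):
            loader = DataLoader(DummyDataset(), batch_size=2, num_workers=workers)
            batch = next(iter(loader))
            print('num_workers=%d: batch shape %s' % (workers, tuple(batch.shape)))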

zhollow commented 3 years ago

Strange. After I changed sampling_processes to 0 and changed it back to 4, it started working.
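
In case someone else runs into this: a workaround that is often suggested for DataLoader workers hanging on Linux is to force the 'spawn' start method before training starts. This is a hypothetical fix on my side, not something from the SpERT code; the default 'fork' method can deadlock if the parent process holds locks (for example from threads in a tokenizer or BLAS library) at fork time:

    # Hypothetical workaround, not part of SpERT: start DataLoader workers in
    # fresh processes instead of forked copies of the parent.
    import torch.multiprocessing as mp

    if __name__ == '__main__':
        mp.set_start_method('spawn', force=True)
        # ... then launch training as usual, e.g. the spert.py train entry point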