dsindex / ntagger

reference pytorch code for named entity tagging

ValueError: the first two dimensions of emissions and mask must match #12

Closed · appledora closed this issue 2 years ago

appledora commented 2 years ago

Hello, I am trying to train on a custom dataset using the default config-glove.json file. I have modified dataset.py accordingly, following the instructions mentioned in a previous BERT-model-related issue, #11. However, I am getting this error:

Epoch 0:   0%|          | 0/5816 [00:00<?, ?it/s]/cm/local/apps/python37/lib/python3.7/site-packages/torchcrf/__init__.py:305: UserWarning: where received a uint8 condition tensor. This behavior is deprecated and will be removed in a future version of PyTorch. Use a boolean condition instead. (Triggered internally at  ../aten/src/ATen/native/TensorCompare.cpp:328.)
  score = torch.where(mask[i].unsqueeze(1), next_score, score)

Epoch 0:   0%|          | 0/5816 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "train.py", line 762, in <module>
    main()
  File "train.py", line 759, in main
    train(args)
  File "train.py", line 594, in train
    eval_loss, eval_f1, best_eval_f1 = train_epoch(model, config, train_loader, valid_loader, epoch_i, best_eval_f1)
  File "train.py", line 114, in train_epoch
    log_likelihood = model.crf(logits, y, mask=mask, reduction='mean')
  File "/cm/local/apps/python37/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/cm/local/apps/python37/lib/python3.7/site-packages/torchcrf/__init__.py", line 90, in forward
    self._validate(emissions, tags=tags, mask=mask)
  File "/cm/local/apps/python37/lib/python3.7/site-packages/torchcrf/__init__.py", line 162, in _validate
    'the first two dimensions of emissions and mask must match, '
ValueError: the first two dimensions of emissions and mask must match, got (16, 180) and (16, 180, 50)

As you can see, the first two dimensions actually DO match: 16 is the batch_size and 180 is the n_ctx attribute from the config file. I am not sure what the 50 represents here. Additionally, I have also checked the shapes of the arguments passed to the model.crf() function. It gives me:

logits shape: torch.Size([16, 180, 14])
mask shape: torch.Size([16, 180, 50])
y shape: torch.Size([16, 180])
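For reference, a quick shape check just before the model.crf() call (a sketch using the same variable names) shows the problem; note that the stray 50 equals char_n_ctx in my config:

# sketch: sanity-check the tensors handed to model.crf() in train.py
print(logits.shape)  # torch.Size([16, 180, 14]) : [batch_size, n_ctx, label_size]
print(y.shape)       # torch.Size([16, 180])
print(mask.shape)    # torch.Size([16, 180, 50]) : 50 == char_n_ctx, so the mask
                     # appears to be built from a 3-D tensor rather than token ids
assert mask.shape == logits.shape[:2], f"mask must be 2-D, got {tuple(mask.shape)}"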

This is the config file I have used:

{
    "emb_class": "glove",
    "enc_class": "bilstm",
    "n_ctx": 180,
    "lowercase": true,
    "token_emb_dim": 300,
    "pad_token": "<pad>",
    "pad_token_id": 0,
    "unk_token": "<unk>",
    "unk_token_id": 1,
    "pos_emb_dim": 100,
    "pad_pos": "<pad>",
    "pad_pos_id": 0,
    "char_n_ctx": 50,
    "char_vocab_size": 262,
    "char_padding_idx": 261,
    "char_emb_dim": 50,
    "char_num_filters": 30,
    "char_kernel_sizes": [3, 9],
    "dropout": 0.3,
    "lstm_hidden_dim": 200,
    "lstm_num_layers": 3,
    "lstm_dropout": 0.0,
    "mha_num_attentions": 8,
    "pad_label": "<pad>",
    "pad_label_id": 0,
    "default_label": "O"
}

Here are the scripts I have used to preprocess and train:

python preprocess.py \
--config=configs/config-glove.json \
--data_dir=data/bangla_glove \
--embedding_path embeddings/bn_glove.39M.300d.txt 

CUDA_LAUNCH_BLOCKING=1 \
python train.py \
--config=configs/config-glove.json \
--data_dir=data/bangla_glove \
--save_path=pytorch-model-glove-bn-ext.pt \
--batch_size=16 \
--eval_batch_size=8 \
--lr=1e-5 \
--epoch=60 \
--use_mha \
--use_crf \
--patience 5 \
--embedding_trainable

dataset.py was modified as follows:

class CoNLLGloveDataset(Dataset):
    def __init__(self, config, path):
        from allennlp.modules.elmo import batch_to_ids
        pad_ids = [config['pad_token_id']] * config['char_n_ctx']
        all_token_ids = []
        # all_pos_ids = []
        all_char_ids = []
        all_label_ids = []
        with open(path,'r',encoding='utf-8') as f:
            print(f"PATH: {path}")
            for line in f:
                # print("Line: ", line)
                line = line.strip()
                items = line.split('\t')
                # print(f"items: {items}")
                token_ids = [int(d) for d in items[1].split()]
                # pos_ids   = [int(d) for d in items[2].split()]
                # using ELMo.batch_to_ids, compute character ids: ex) 'The' [259, 85, 105, 102, 260, 261, 261, ...]
                # (actually byte-based, char_vocab_size == 262, char_padding_idx == 261)
                tokens    = items[2].split()
                print(f"tokens: {tokens}")
                char_ids  = batch_to_ids([tokens])[0].detach().cpu().numpy().tolist()
                for _ in range(len(token_ids) - len(char_ids)):
                    char_ids.append(pad_ids)
                label_ids = [int(d) for d in items[0].split()]
                all_token_ids.append(token_ids)
                # all_pos_ids.append(pos_ids)
                all_char_ids.append(char_ids)
                all_label_ids.append(label_ids)
        all_token_ids = torch.tensor(all_token_ids, dtype=torch.long)
        # all_pos_ids = torch.tensor(all_pos_ids, dtype=torch.long)
        all_char_ids = torch.tensor(all_char_ids, dtype=torch.long)
        all_label_ids = torch.tensor(all_label_ids, dtype=torch.long)

        self.x = TensorDataset(all_token_ids,  all_char_ids)
        self.y = all_label_ids

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]
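With this change, each batch now comes out as a (token_ids, char_ids) pair plus labels. A minimal sketch of what the loader yields (the data path is hypothetical):

from torch.utils.data import DataLoader

# sketch: inspect one batch from the modified dataset ('train.txt' is a
# hypothetical preprocessed file path)
loader = DataLoader(CoNLLGloveDataset(config, 'data/bangla_glove/train.txt'), batch_size=16)
x, y = next(iter(loader))
print(x[0].shape)  # [16, 180]      token ids
print(x[1].shape)  # [16, 180, 50]  char ids (char_n_ctx == 50); pos ids are gone
print(y.shape)     # [16, 180]      label ids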

Appreciate your help!!

dsindex commented 2 years ago

Did you modify ‘model.py’?

https://github.com/dsindex/ntagger/blob/master/model/model.py#L303

You need to modify it so that it does not use the ‘pos embedding’.

appledora commented 2 years ago

Yeah, I commented out the pos references and updated the indices.

class GloveLSTMCRF(BaseModel):
    def __init__(self,
            config,
            embedding_path,
            label_size,
            # pos_size,
            emb_non_trainable=True,
            use_crf=False,
            use_ncrf=False,
            use_char_cnn=False,
            use_mha=False):
        super().__init__(config=config)

        self.config = config
        self.device = config['args'].device
        self.seq_size = config['n_ctx']
        # pos_emb_dim = config['pos_emb_dim']
        lstm_hidden_dim = config['lstm_hidden_dim']
        lstm_num_layers = config['lstm_num_layers']
        lstm_dropout = config['lstm_dropout']
        self.use_char_cnn = use_char_cnn
        self.use_crf = use_crf
        self.use_ncrf = use_ncrf
        self.use_mha = use_mha
        mha_num_attentions = config['mha_num_attentions']

        # glove embedding layer
        weights_matrix = super().load_embedding(embedding_path)
        vocab_dim, token_emb_dim = weights_matrix.size()
        padding_idx = config['pad_token_id']
        self.embed_token = super().create_embedding_layer(vocab_dim, token_emb_dim, weights_matrix=weights_matrix, non_trainable=emb_non_trainable, padding_idx=padding_idx)

        # pos embedding layer
        # self.pos_vocab_size = pos_size
        # padding_idx = config['pad_pos_id']
        # self.embed_pos = super().create_embedding_layer(self.pos_vocab_size, pos_emb_dim, weights_matrix=None, non_trainable=False, padding_idx=padding_idx)

        emb_dim = token_emb_dim

        # char embedding layer
        if self.use_char_cnn:
            self.charcnn = CharCNN(config)
            emb_dim = emb_dim + self.charcnn.last_dim

        # BiLSTM layer
        self.lstm = nn.LSTM(input_size=emb_dim,
                            hidden_size=lstm_hidden_dim,
                            num_layers=lstm_num_layers,
                            dropout=lstm_dropout,
                            bidirectional=True,
                            batch_first=True)
        self.lstm_dim = lstm_hidden_dim*2

        self.dropout = nn.Dropout(config['dropout'])

        # Multi-Head Attention layer
        self.mha_dim = self.lstm_dim
        if self.use_mha:
            self.mha = nn.MultiheadAttention(self.lstm_dim, num_heads=mha_num_attentions)
            self.layernorm_mha = nn.LayerNorm(self.mha_dim)

        # projection layer
        self.label_size = label_size
        self.linear = nn.Linear(self.mha_dim, self.label_size)
        if self.config['args'].use_isomax:
            self.linear = IsoMax(self.mha_dim, self.label_size)

        # CRF layer
        if self.use_crf:
            if self.use_ncrf:
                use_gpu = True
                if self.device == 'cpu': use_gpu = False
                self.crf = NCRF(self.label_size-2, use_gpu) # -2 because of START_TAG, STOP_TAG
            else:
                self.crf = CRF(num_tags=self.label_size, batch_first=True)

    def forward(self, x):
        # x[0, 1] : [batch_size, seq_size]
        # x[2]    : [batch_size, seq_size, char_n_ctx]
        token_ids = x[0]
        # pos_ids = x[1]

        mask = torch.sign(torch.abs(token_ids)).to(torch.uint8).to(self.device)
        # mask : [batch_size, seq_size]
        lengths = torch.sum(mask.to(torch.long), dim=1)
        # lengths : [batch_size]

        # 1. Embedding
        token_embed_out = self.embed_token(token_ids)
        # token_embed_out : [batch_size, seq_size, token_emb_dim]
        # pos_embed_out = self.embed_pos(pos_ids)
        # pos_embed_out : [batch_size, seq_size, pos_emb_dim]
        if self.use_char_cnn:
            char_ids = x[1]
            # char_ids : [batch_size, seq_size, char_n_ctx]
            charcnn_out = self.charcnn(char_ids)
            # charcnn_out : [batch_size, seq_size, self.charcnn.last_dim]
            embed_out = torch.cat([token_embed_out,  charcnn_out], dim=-1)
            # embed_out : [batch_size, seq_size, emb_dim]
        else:
            embed_out = torch.cat([token_embed_out, ], dim=-1)
            # embed_out : [batch_size, seq_size, emb_dim]
        embed_out = self.dropout(embed_out)

        # 2. LSTM
        # FIXME : pytorch 1.7.0 bug https://github.com/pytorch/pytorch/issues/43227 , lengths.cpu()
        packed_embed_out = torch.nn.utils.rnn.pack_padded_sequence(embed_out, lengths.cpu(), batch_first=True, enforce_sorted=False)
        lstm_out, (h_n, c_n) = self.lstm(packed_embed_out)
        lstm_out, _ = torch.nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True, total_length=self.seq_size)
        # lstm_out : [batch_size, seq_size, self.lstm_dim == lstm_hidden_dim*2]
        lstm_out = self.dropout(lstm_out)

        # 3. MHA
        if self.use_mha:
            # reference : https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
            #             https://discuss.pytorch.org/t/how-to-add-padding-mask-to-nn-transformerencoder-module/63390/3
            query = lstm_out.permute(1, 0, 2)
            # query : [seq_size, batch_size, self.lstm_dim]
            key = query
            value = query
            key_padding_mask = mask.ne(1) # attention_mask => mask = [[1, 1, ..., 0, ...]] => [[False, False, ..., True, ...]]
            attn_output, attn_output_weights = self.mha(query, key, value, key_padding_mask=key_padding_mask)
            # attn_output : [seq_size, batch_size, self.mha_dim]
            mha_out = attn_output.permute(1, 0, 2)
            # mha_out : [batch_size, seq_size, self.mha_dim]
            # residual, layernorm, dropout
            mha_out = self.layernorm_mha(mha_out + lstm_out)
            mha_out = self.dropout(mha_out)
        else:
            mha_out = lstm_out
            # mha_out : [batch_size, seq_size, self.mha_dim]

        # 4. Output
        logits = self.linear(mha_out)
        # logits : [batch_size, seq_size, label_size]
        if not self.use_crf: return logits
        if self.use_ncrf:
            scores, prediction = self.crf._viterbi_decode(logits, mask.bool())
        else:
            prediction = self.crf.decode(logits)
            prediction = torch.as_tensor(prediction, dtype=torch.long)
        # prediction : [batch_size, seq_size]
        return logits, prediction

dsindex commented 2 years ago

I got it. In train.py, the ‘mask’ is built from ‘x[1]’, which is implicitly treated as the ‘pos ids’, i.e. non-zero ids followed by zero padding:

(1, 3, 6, 1, 4, …, 0, 0, 0)

https://github.com/dsindex/ntagger/blob/master/train.py#L91

mask = torch.sign(torch.abs(x[1])).to(torch.uint8)

But you removed the ‘pos ids’ from the dataset, so ‘x[1]’ is now the 3-D char ids tensor.

I think you can use ‘x[0]’ (the token ids) instead.
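A minimal sketch of that change in train.py (around the line linked above); passing a bool mask should additionally silence the torchcrf uint8 deprecation warning seen at the top of the log, though that part is an extra suggestion rather than the original code:

# build the mask from x[0] (token ids) instead of x[1], which now holds char ids
mask = torch.sign(torch.abs(x[0])).to(torch.uint8)
# mask : [batch_size, seq_size] -- first two dims now match the emissions
log_likelihood = model.crf(logits, y, mask=mask.bool(), reduction='mean')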

appledora commented 2 years ago

Modified the indexing in train.py, and training seems to have started :smile: Thanks a bunch. Feel free to close the issue!!