ZhixiuYe / HSCRF-pytorch

ACL 2018: Hybrid semi-Markov CRF for Neural Sequence Labeling (http://aclweb.org/anthology/P18-2038)

How to use loss from HSCRF? #9

Closed: sebastianGehrmann closed this issue 6 years ago

sebastianGehrmann commented 6 years ago

Hey,

When I run a forward pass with your HSCRF module, I get a loss like the one below. In your training code, you use it like this: epoch_loss += utils.to_scalar(loss) (here: https://github.com/ZhixiuYe/HSCRF-pytorch/blob/master/train.py#L177).

What exactly does this do? Why is the loss not directly a scalar?

Loss

tensor([  -40.1046, -1039.7493, -2039.6393, -1039.3293, -1038.7677,
        -2038.6620, -2038.5282, -2038.3925, -2038.2518, -2038.1088,
        -2037.9637, -2037.8209, -2037.6803, -2037.5381, -2037.3993,
        -2037.2581, -2037.1168, -2036.9700, -2036.8280, -2036.6908,
        -2036.5507, -2036.4150, -2036.2721, -2036.1278, -2035.9827,
        -2035.8431, -2035.7041, -2035.5673, -2035.4296, -2035.2881,
        -2035.1510, -2035.0101, -2034.8658, -2034.7272, -2034.5862,
        -2034.4449, -2034.3057, -2034.1682, -2034.0237, -1033.9022,
        -1040.0959, -2040.0959, -2039.9283, -2039.7893, -2039.6503,
        -2039.5123, -2039.3701, -2039.2305, -2039.0912, -2038.9553,
        -2038.8187, -2038.6780, -2038.5389, -2038.3978,   -36.7899,
          -40.0815, -1040.0814, -2039.9296, -2039.7889, -2039.6522,
        -2039.5165, -2039.3783, -2039.2343, -2039.0922, -2038.9550,
        -2038.8119, -1038.5255, -1038.3424,   -37.4315, -1040.0959,
        -2040.0959, -2039.9336, -2039.7943, -1039.5092, -1039.1903,
        -2039.0892, -2038.9465, -2038.8047, -2038.6637, -2038.5248,
        -2038.3876, -2038.2494, -1037.9631, -1037.5199, -2037.4154,
        -2037.2767, -2037.1420, -2037.0059, -2036.8678, -2036.7299,
        -2036.5885, -2036.4471, -2036.3021, -2036.1615, -2036.0237,
        -2035.8873, -2035.7494, -2035.6093, -2035.4706, -1035.3413,
          -40.0959, -1038.7875, -2038.6840, -2038.5472, -2038.4045,
        -2038.2649, -2038.1256, -2037.9882, -2037.8506, -2037.7131,
        -2037.5725, -2037.4362, -2037.2994, -2037.1644, -2037.0319,
        -2036.8983, -2036.7588, -2036.6194, -2036.4772, -2036.3358,
        -1036.2131,   -40.0959, -1039.0562, -2038.9530, -2038.8147,
        -2038.6696, -2038.5270, -2038.3877, -2038.2422, -2038.1002,
        -2037.9564, -2037.8147, -1037.6818,   -40.1046, -1039.7607,
        -2039.6508, -2039.5139, -2039.3763, -2039.2395, -2039.1012,
        -2038.9645, -2038.8303, -2038.6946, -2038.5599, -2038.4235,
        -1038.1422, -1037.9800,   -37.7956,   -40.1046, -1039.7583,
        -1039.5063, -1039.3461, -2039.2405, -2039.1073, -2038.9686,
        -2038.8303, -2038.6901, -2038.5526, -2038.4150, -2038.2737,
        -2038.1295, -2037.9882, -1037.8634, -1040.0959, -2040.0959,
        -2039.9303,   -40.0959, -1038.6449, -2038.5428, -2038.4020,
        -2038.2581, -2038.1158, -2037.9739, -2037.8346, -2037.6959,
        -2037.5570, -2037.4202, -2037.2764, -2037.1334, -2036.9911,
        -2036.8450, -1036.7113,   -40.1044, -1039.9014, -2039.7880,
          -40.0959, -1038.3672, -2038.2637, -2038.1252, -2037.9850,
        -2037.8483, -2037.7079, -2037.5686, -2037.4323, -2037.2991,
        -2037.1674, -2037.0237, -2036.8802, -2036.7416, -2036.6041,
        -2036.4674, -2036.3307, -2036.1898, -2036.0443, -2035.9014,
        -2035.7561, -2035.6185, -2035.4778, -1035.3492], device='cuda:0')
ZhixiuYe commented 6 years ago
  1. When I print loss after line 175, I get the following output:

    Variable containing:
    92.9371
    [torch.cuda.FloatTensor of size 1 (GPU 0)]

    I guess you are using PyTorch 0.4.0, but this code was written for PyTorch 0.2.0. You can install that version and try again.

  2. The to_scalar function only records the current loss so that it can be printed (https://github.com/ZhixiuYe/HSCRF-pytorch/blob/master/train.py#L199). It is fine to delete line 177.

sebastianGehrmann commented 6 years ago

Unfortunately, the rest of my code uses PyTorch 0.4.0, so I can't mix and match and need to port the SCRF code. As guidance, could you print some more tensor sizes so I can see what I need to fix?

My tag set size is 4: I am not doing NER and only want binary labels, so my tags are (no tag, tag, start, end). Using test data with batch size 10, I get the following sizes:

    def forward(self, feats, mask_word, tags, mask_tag):
        self.batch_size = feats.size(0)
        self.sent_len = feats.size(1)
        # feats: (10 x 48 x 256)
        # mask_word: (10)
        # tags: (10, 40, 4)
        # mask_tag: (10, 40)
        feats = self.dense(feats)
        self.SCRF_scores = self.HSCRF_scores(feats)
        # self.SCRF_scores: (10, 48, 48, 4, 4)
        forward_score = self.get_logloss_denominator(self.SCRF_scores, mask_word)
        # forward_score: (1)
        numerator = self.get_logloss_numerator(tags, self.SCRF_scores, mask_tag)
        # numerator: (209)
        loss =  (forward_score - numerator.sum()) / self.batch_size
        # loss: (209)
        return loss

Here are the two functions annotated:

    def get_logloss_numerator(self, goldfactors, scores, mask):
        # mask: (10, 40)
        batch_size = scores.size(0) # 10
        sent_len = scores.size(1) # 48
        tagset_size = scores.size(3) # 4
        goldfactors = goldfactors[:, :, 0]*sent_len*tagset_size*tagset_size + goldfactors[:,:,1]*tagset_size*tagset_size+goldfactors[:,:,2]*tagset_size+goldfactors[:,:,3]
        # goldfactors: (10, 40)
        factorexprs = scores.view(batch_size, -1)
        # factorexprs: (10, 36864)
        val = torch.gather(factorexprs, 1, goldfactors)
        # val: (10, 40)
        numerator = val.masked_select(mask)
        # numerator: (209)
        return numerator

    def get_logloss_denominator(self, scores, mask):
        logalpha = Variable(torch.FloatTensor(self.batch_size, self.sent_len+1, self.tagset_size).fill_(-10000.)).cuda()
        # logalpha: (10, 49, 4)
        logalpha[:, 0, self.start_id] = 0.
        istarts = [0] * self.ALLOWED_SPANLEN + list(range(self.sent_len - self.ALLOWED_SPANLEN + 1))
        # len(istarts): 49
        for i in range(1, self.sent_len+1):
            tmp = scores[:, istarts[i]:i, i-1] + \
                    logalpha[:, istarts[i]:i].unsqueeze(3).expand(self.batch_size, i - istarts[i], self.tagset_size, self.tagset_size)
            tmp = tmp.transpose(1, 3).contiguous().view(self.batch_size, self.tagset_size, (i-istarts[i])*self.tagset_size)
            max_tmp, _ = torch.max(tmp, dim=2)
            tmp = tmp - max_tmp.view(self.batch_size, self.tagset_size, 1)
            logalpha[:, i] = max_tmp + torch.log(torch.sum(torch.exp(tmp), dim=2))

        mask = mask.unsqueeze(1).unsqueeze(1).expand(self.batch_size, 1, self.tagset_size)
        # mask: (10,1,4)
        alpha = torch.gather(logalpha, 1, mask).squeeze(1)
        # alpha: (10,4)
        return alpha[:,self.stop_id].sum() # return: (1)
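To see what get_logloss_denominator computes for a single sentence, here is a minimal plain-Python sketch of the same semi-Markov forward recursion, checked against brute-force enumeration of every labelled segmentation. It assumes scores[s][e][prev][cur] is the score of a segment covering tokens s..e labelled cur after a segment labelled prev (the layout the tensor code indexes), and, like the code above, it returns alpha[n][stop_id]. The names log_partition, brute_force and segmentations are hypothetical. Rows of alpha are appended to a list rather than written into a preallocated buffer; that is the same "no in-place writes" pattern the thread later reports resolving the autograd errors with torch.cat.

```python
import itertools
import math
import random

NEG_INF = -1e4  # stand-in for -inf, mirroring the fill_(-10000.) in the thread


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) via the max-shift trick."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def log_partition(scores, n, num_tags, max_span, start_id, stop_id):
    """Semi-Markov forward recursion over one sentence of length n.

    scores[s][e][prev][cur]: score of a segment s..e (inclusive) labelled
    cur, preceded by a segment labelled prev.
    """
    # alpha[i][y]: log-sum over segmentations of tokens 0..i-1 whose last
    # segment is labelled y. Rows are appended, never overwritten in place.
    alpha = [[0.0 if y == start_id else NEG_INF for y in range(num_tags)]]
    for i in range(1, n + 1):
        first_start = max(0, i - max_span)  # plays the role of istarts[i]
        row = []
        for cur in range(num_tags):
            terms = [alpha[s][prev] + scores[s][i - 1][prev][cur]
                     for s in range(first_start, i)
                     for prev in range(num_tags)]
            row.append(logsumexp(terms))
        alpha.append(row)
    return alpha[n][stop_id]


def segmentations(n, max_span):
    """Yield every split of tokens 0..n-1 into (start, end) spans of length <= max_span."""
    if n == 0:
        yield []
        return
    for length in range(1, min(max_span, n) + 1):
        for rest in segmentations(n - length, max_span):
            yield rest + [(n - length, n - 1)]


def brute_force(scores, n, num_tags, max_span, start_id, stop_id):
    """Same quantity by explicit enumeration (exponential; for checking only)."""
    totals = []
    for spans in segmentations(n, max_span):
        for tags in itertools.product(range(num_tags), repeat=len(spans)):
            if tags[-1] != stop_id:  # last segment must carry stop_id
                continue
            prev, total = start_id, 0.0
            for (s, e), cur in zip(spans, tags):
                total += scores[s][e][prev][cur]
                prev = cur
            totals.append(total)
    return logsumexp(totals)


# Usage: random scores, 4 tokens, 3 hypothetical tag ids (0 = start, 2 = stop).
rng = random.Random(0)
N, T, MAX_SPAN = 4, 3, 2
SCORES = [[[[rng.uniform(-1.0, 1.0) for _ in range(T)] for _ in range(T)]
           for _ in range(N)] for _ in range(N)]
fast = log_partition(SCORES, N, T, MAX_SPAN, start_id=0, stop_id=2)
slow = brute_force(SCORES, N, T, MAX_SPAN, start_id=0, stop_id=2)
print(fast, slow)  # the two agree up to floating-point rounding
```

The max-shift inside logsumexp is the same trick as the max_tmp subtraction in get_logloss_denominator; it keeps exp() from overflowing.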

Edit: As it turns out, I summed the wrong tensor; the sizes are all correct. I am now getting a lot of "leaf variable has been moved into the graph interior" errors, caused by the indexing and in-place overwriting of values in these functions. Did you encounter these errors when you built the model? How did you address them?

ZhixiuYe commented 6 years ago
  1. In the following code, loss should obviously have size 1.

        forward_score = self.get_logloss_denominator(self.SCRF_scores, mask_word)
        # forward_score: (1)
        numerator = self.get_logloss_numerator(tags, self.SCRF_scores, mask_tag)
        # numerator: (209)
        loss =  (forward_score - numerator.sum()) / self.batch_size
        # loss: (209)
  2. As for "leaf variable has been moved into the graph interior": I guess it is because in PyTorch 0.4.0 the Variable class was merged into Tensor. But I'm not very familiar with PyTorch 0.4.0, so I don't know the details.
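Point 1 is shape bookkeeping: both sides are reduced to a single number before subtracting. A minimal plain-Python sketch (no PyTorch), assuming the denominator returns the batch sum of per-sentence log-partition values and the numerator returns the flat vector of gold segment scores; batch_nll and its arguments are hypothetical names. If only one side were reduced, the subtraction would instead broadcast to one entry per gold segment.

```python
def batch_nll(log_partitions, gold_segment_scores, batch_size):
    """Batch-averaged negative log-likelihood as one Python float.

    log_partitions: one log Z per sentence (what get_logloss_denominator
        sums over the batch before returning).
    gold_segment_scores: flat list with the score of every gold segment in
        the batch (what get_logloss_numerator returns after masked_select).
    """
    forward_score = sum(log_partitions)    # scalar, like alpha[:, stop_id].sum()
    gold_score = sum(gold_segment_scores)  # scalar, like numerator.sum()
    return (forward_score - gold_score) / batch_size


# Usage: 2 sentences with log Z = 5.0 and 7.0; four gold segments in total.
loss = batch_nll([5.0, 7.0], [1.0, 2.0, 0.5, 1.5], batch_size=2)
print(loss)  # 3.5
```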

sebastianGehrmann commented 6 years ago

I managed to refactor this to use torch.cat operations, so the error is resolved. Now I've run into a problem I can't quite understand from your code: your HSCRF_scores function only computes the likelihoods for positive labels, but keeps O/start/end at -1e5 (by setting them to values in the m30000 tensor). Where in your SCRF code do you actually compute the probability that a tag is O?

ZhixiuYe commented 6 years ago

First of all, you can refer to the paper "Semi-Markov Conditional Random Fields for Information Extraction" for details about semi-Markov CRFs. Actually, HSCRF_scores computes scores, not likelihoods. The scores tensor has shape (self.batch_size, self.sent_len, self.sent_len, self.tagset_size, self.tagset_size), which corresponds to g^k(j, x, s) in that paper.
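With that layout, the goldfactors arithmetic in get_logloss_numerator is the row-major offset of one cell, scores[b, start, end, prev_tag, cur_tag], inside row b of scores.view(batch_size, -1). The plain-Python sketch below checks this, assuming each gold factor is ordered (start, end, previous tag, current tag) to match the dimension order; flat_index and flatten are hypothetical helpers.

```python
import itertools


def flat_index(start, end, prev_tag, cur_tag, sent_len, tagset_size):
    """Row-major offset of scores[b, start, end, prev_tag, cur_tag] within
    row b of scores.view(batch_size, -1); same arithmetic as goldfactors."""
    return (start * sent_len * tagset_size * tagset_size
            + end * tagset_size * tagset_size
            + prev_tag * tagset_size
            + cur_tag)


def flatten(nested):
    """Row-major flattening of a 4-level nested list, like .view(-1) on a
    contiguous (sent_len, sent_len, tagset_size, tagset_size) tensor."""
    return [x for a in nested for b in a for c in b for x in c]


# Usage: label every cell with its own coordinates, then look cells up by offset.
L, T = 3, 4
nested = [[[[(s, e, p, c) for c in range(T)] for p in range(T)]
           for e in range(L)] for s in range(L)]
flat = flatten(nested)
print(flat[flat_index(1, 2, 0, 3, L, T)])  # (1, 2, 0, 3)
# With the sizes from the thread (sent_len=48, tagset_size=4) each row has
# 48 * 48 * 4 * 4 = 36864 entries, so the last valid offset is 36863.
print(flat_index(47, 47, 3, 3, sent_len=48, tagset_size=4))  # 36863
```

torch.gather(factorexprs, 1, goldfactors) then picks exactly these offsets, one per gold segment.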

sebastianGehrmann commented 6 years ago

Thanks for the link to the paper. It might help to annotate your code with the corresponding equations to make it easier to understand. I still don't get why O is never scored. Eq. (2) in the linked paper defines g^k in terms of y_j and y_{j-1}, but the code only scores the different tags.

ZhixiuYe commented 6 years ago

I get it! In the "if span == 0:" branch I calculate the score of O. I assume that O can only be scored when its length is one; when its length is more than one, we don't calculate its score.
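A toy plain-Python sketch of that rule: segment_score gives O its normal score only when span == 0 (a one-token segment) and a large negative score otherwise, so a run of O tokens can only be labelled as consecutive one-token O segments. The function names, tag ids and base_score are hypothetical, and transition scores are left out to keep the exhaustive search short.

```python
import itertools

NEG = -1e4  # "forbidden" score, like the m10000 tensor in the thread


def segment_score(start, end, tag, o_id, base_score):
    """Score of the segment start..end (inclusive) labelled `tag`.

    O is only scored for length-one segments (span == end - start == 0);
    a longer O segment gets the forbidden score instead.
    """
    if tag == o_id and end - start > 0:
        return NEG
    return base_score(start, end, tag)


def segmentations(n, max_span):
    """Yield every split of tokens 0..n-1 into (start, end) spans of length <= max_span."""
    if n == 0:
        yield []
        return
    for length in range(1, min(max_span, n) + 1):
        for rest in segmentations(n - length, max_span):
            yield rest + [(n - length, n - 1)]


def best_labeling(n, max_span, num_tags, o_id, base_score):
    """Exhaustive argmax over labelled segmentations (transitions ignored)."""
    best, best_total = None, float("-inf")
    for spans in segmentations(n, max_span):
        for tags in itertools.product(range(num_tags), repeat=len(spans)):
            total = sum(segment_score(s, e, t, o_id, base_score)
                        for (s, e), t in zip(spans, tags))
            if total > best_total:
                best, best_total = list(zip(spans, tags)), total
    return best, best_total


# Usage with two hypothetical tag ids and a toy base score that favours O.
ENT_ID, O_ID = 0, 1


def base_score(start, end, tag):
    return 1.0 if tag == O_ID else -0.5


labels, total = best_labeling(3, 3, 2, O_ID, base_score)
print(labels, total)  # [((0, 0), 1), ((1, 1), 1), ((2, 2), 1)] 3.0
```

A three-token O run comes out as three one-token O segments, because the single three-token O segment is effectively ruled out.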

sebastianGehrmann commented 6 years ago

I see, but even when I print only the result of the code for span == 0:

    tmp = torch.cat((self.transition[:, :validtag_size].unsqueeze(0).unsqueeze(0) + emb_x[:, 0, :, :validtag_size].unsqueeze(2),
                     m10000,
                     self.transition[:, -2:].unsqueeze(0).unsqueeze(0) + emb_x[:, 0, :, -2:].unsqueeze(2)), 3)
    scores[:, diag0, diag0] = tmp

every entry looks like this:

 [[ 6.1834e-01, -1.0000e+04, -4.2706e-01,  2.8736e-01],
  [-4.2289e-01, -1.0000e+04, -4.3145e-02, -1.0890e+00],
  [-5.2040e-01, -1.0000e+04, -3.2427e-01, -1.1558e+00],
  [-6.0971e-01, -1.0000e+04,  2.9183e-01,  5.1828e-01]],

I only have one tag, so the first entry in each row is for that tag, the one for O is not calculated at all, and the last two are START and STOP.

ZhixiuYe commented 6 years ago

I added some annotations: https://github.com/ZhixiuYe/HSCRF-pytorch/commit/c7142f208d617934cb707e484c96e911b433d67f#diff-e90865298a808f704cff7317a658876e
These four entries are a tag (like PER), STOP, START, and O, respectively.