octanove / shiba

Pytorch implementation and pre-trained Japanese model for CANINE, the efficient character-level transformer.

Problem/question with `random_mask` in `masking.py` #12

Open sven-nm opened 3 months ago

sven-nm commented 3 months ago

Hey guys, thanks for this awesome adaptation of CANINE 😊 I've been working on adapting it to other languages and I came across weird empty masks. I think the problem is in `training/masking.py`, in the function `random_mask`. We have the following (starting at line 42):

    special_tokens_mask = _special_tokens_mask_from_range(input_ids, range(MIN_SPECIAL_TOKEN, MAX_SPECIAL_TOKEN))
    # ⚠️ So far okay

    special_tokens_mask = special_tokens_mask | attention_mask.bool()
    # ⚠️ Here we have a problem: say the attention mask is a `torch.ones()` of shape (batch_size, seq_len) (e.g. 8, 2048); then we end up with `special_tokens_mask` as a `torch.ones()` of the same shape, which we use later on (see below)

    mask_count = math.floor(input_ids.shape[1] * masking_percent)

    indices_to_mask = []
    for unmaskable_indices, inputs in zip(special_tokens_mask, input_ids):
        # compute the possible indices we could mask for this input
        maskable_indices = torch.arange(inputs.shape[0]).masked_select(~unmaskable_indices)
        # ⚠️ Here we are using `~unmaskable_indices`, so basically no input id can be masked!

        # take mask_count random indices, get the maskable indices using random indices
        indices_to_mask.append(maskable_indices[torch.randperm(maskable_indices.shape[0])[:mask_count]])

I've marked my comments with a ⚠️. Am I missing something here? My hunch is that line 43 should be `special_tokens_mask = special_tokens_mask | ~attention_mask.bool()`.
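For reference, here's a minimal standalone sketch of what I mean (not the repo's code; it just assumes the usual convention that `attention_mask` is 1 for real tokens and 0 for padding):

    import torch

    batch_size, seq_len = 2, 8
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)        # no padding anywhere
    special_tokens_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)  # no special tokens either

    # Current logic: OR-ing with the attention mask marks every real token as unmaskable.
    unmaskable = special_tokens_mask | attention_mask.bool()
    print(unmaskable.all())   # tensor(True) -> nothing is maskable, so the masks come out empty

    # Proposed fix: only positions *outside* the attention mask (i.e. padding) are unmaskable.
    unmaskable_fixed = special_tokens_mask | ~attention_mask.bool()
    print(unmaskable_fixed.any())  # tensor(False) -> every real token is a masking candidate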

Mindful commented 2 months ago

@sven-nm

Hi there, I apologize for the delayed response - I totally missed the notification for this issue.

Unfortunately it has been a long time since I wrote this code, so I can't answer very confidently, but it looks like you're right. That said, I double-checked, and this logic is used in all three masking methods (including the one we trained with); if our masking were genuinely completely broken, I'm not sure how the model would have learned anything. Given that, I am... not sure what to tell you, to be honest.

Edit: if you do come to a conclusion you're confident about and end up fixing something, please feel free to open a PR.

ganeshkrishnan1 commented 6 days ago

@sven-nm does this affect training? We are planning to start training on our English corpus, and I was wondering whether this bug has a major effect on the training loss.

Mindful commented 6 days ago

@ganeshkrishnan1 If this is actually broken, it should pretty much break training entirely (which may or may not be reflected in the loss). However, it should be pretty easy to figure out whether there's a problem, and if there is, it would be very easy to fix - it would just be the change suggested in the original issue.

ganeshkrishnan1 commented 6 days ago

I haven't started the training for this yet. But how do you figure out whether there is a problem if it's not evident from the loss? Do you mean running it on real-world classification (or similar) tasks?

Mindful commented 6 days ago

@ganeshkrishnan1 The easiest ways are either to step through the training code with a debugger and look at what is actually being masked, or, if you prefer, to train on an extremely small, simplified dataset (like strings of consecutive letters) and see whether the model can learn it properly.
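For example, something like the sketch below (which just reproduces the quoted logic on dummy data rather than calling the real `random_mask`; the `flip_attention` flag is only there to compare the two variants) should show immediately whether any positions get selected at all:

    import math
    import torch

    def count_masked(input_ids, attention_mask, masking_percent=0.15, flip_attention=False):
        # Pretend there are no special tokens, so the attention mask is the only
        # thing that can make positions unmaskable.
        special_tokens_mask = torch.zeros_like(input_ids, dtype=torch.bool)
        attn = ~attention_mask.bool() if flip_attention else attention_mask.bool()
        special_tokens_mask = special_tokens_mask | attn

        mask_count = math.floor(input_ids.shape[1] * masking_percent)
        total = 0
        for unmaskable_indices, inputs in zip(special_tokens_mask, input_ids):
            maskable_indices = torch.arange(inputs.shape[0]).masked_select(~unmaskable_indices)
            total += maskable_indices[torch.randperm(maskable_indices.shape[0])[:mask_count]].numel()
        return total

    input_ids = torch.randint(5, 100, (8, 2048))
    attention_mask = torch.ones(8, 2048, dtype=torch.long)

    print(count_masked(input_ids, attention_mask))                       # current logic: 0 masked positions
    print(count_masked(input_ids, attention_mask, flip_attention=True))  # proposed fix: 8 * 307 masked positions

If the first number comes out as 0 (or suspiciously small once padding and special tokens are involved), then the fix from the original issue is the way to go.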

Mindful commented 3 hours ago

@ganeshkrishnan1 Btw, if you do figure this out one way or the other, please let me know. I can update the code in the main branch (or you can open a PR) if changes are necessary.

ganeshkrishnan1 commented 18 minutes ago

Our team is going to start testing this early next week. I will keep you updated.