IamAdiSri / hf-trim

Reduce the size of pretrained Hugging Face models via vocabulary trimming.
Mozilla Public License 2.0

Mbart embedding layer pruning #3

Closed. BakingBrains closed this issue 1 year ago.

BakingBrains commented 1 year ago

Thanks for the great repo.

In the repo https://github.com/hyunwoongko/asian-bart, the author prunes the MBart embedding layer. I want to do the same for a particular language. Any suggestions on where I should look?

Thanks and Regards.

IamAdiSri commented 1 year ago

Yes, you can use this package for this purpose. I've added some code below so you can try it out.

from transformers import MBartConfig, MBartTokenizer, MBartForConditionalGeneration
from hftrim.TokenizerTrimmer import TokenizerTrimmer
from hftrim.ModelTrimmers import MBartTrimmer

# the list below needs to be replaced with the training and target language corpus for your use-case
# I've had decent success with 10k training and target samples but the more the better
data = [
        " UN Chief Says There Is No Military Solution in Syria", # source language
        "Şeful ONU declară că nu există o soluţie militară în Siria" # target language
]

# load pretrained config, tokenizer and model
config = MBartConfig.from_pretrained("facebook/mbart-large-en-ro")
tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")

# trim tokenizer
tt = TokenizerTrimmer(tokenizer)
tt.make_vocab(data)
tt.make_tokenizer()

# trim model
mt = MBartTrimmer(model, config, tt.trimmed_tokenizer)
mt.make_weights(tt.trimmed_vocab_ids)
mt.make_model()
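
If you want to keep the trimmed artifacts around, you can save them like any other Hugging Face tokenizer and model. A minimal sketch reusing the tt and mt objects above (the directory name 'trimmed-mbart-en-ro' is just an example):

# save the trimmed tokenizer and model so they can be reloaded later with from_pretrained()
tt.trimmed_tokenizer.save_pretrained('trimmed-mbart-en-ro')
mt.trimmed_model.save_pretrained('trimmed-mbart-en-ro')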

Let me know if you run into any issues and I'll try to help. :)

BakingBrains commented 1 year ago

@IamAdiSri Thank you for the suggestion. I did the same thing, but after pruning, his model size was reduced to 387 MB for English. Below are the details of the https://github.com/hyunwoongko/asian-bart model, as listed in that repo.

English model
vocab size: 32k
model size: 387M
languages: English (en_XX)
architecture: Transformer 12 Encoder + 12 Decoder
name: hyunwoongko/asian-bart-en

I want to do something similar for Hindi, but after pruning my model is about 1.2 GB (even though I have around a 32k vocab). Am I missing something here?

Thanks and Regards

IamAdiSri commented 1 year ago

Hindi will definitely result in a bigger model than English, since it has more characters and therefore more tokens. Also, each character in Hindi has its own symbols for matras and half characters. I think 1.2 GB sounds fine.
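
For intuition on why vocabulary size matters, the embedding and LM-head matrices grow linearly with the number of tokens you keep. A rough back-of-the-envelope sketch, assuming mbart-large's hidden size of 1024 and fp32 weights (the vocabulary sizes are purely illustrative):

# approximate fp32 memory of a (vocab_size x d_model) embedding matrix
d_model, bytes_per_param = 1024, 4
for vocab_size in (32_000, 64_000, 128_000):
    mb = vocab_size * d_model * bytes_per_param / 1e6
    print(f"{vocab_size:>7} tokens -> ~{mb:.0f} MB")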

By the way, are you trying to create a model that can translate, or are you trying to use MBart only as a text encoder? If you're trying to translate between English and Hindi, you'll also end up with a bigger model, since the tokens of both languages need to be accounted for.

IamAdiSri commented 1 year ago

One more thing: hyunwoongko's repository doesn't mention which MBart checkpoint they started with and pruned down. They might be pruning a smaller MBart variant to begin with, which would give them a smaller model.
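
A quick way to sanity-check where the size is going is to print the trimmed vocabulary size and the total parameter count. A minimal sketch, assuming the trimmed tokenizer and model were saved to a directory named 'trimmed-mbart' (the name is illustrative):

from transformers import MBartTokenizer, MBartForConditionalGeneration

tokenizer = MBartTokenizer.from_pretrained("trimmed-mbart")
model = MBartForConditionalGeneration.from_pretrained("trimmed-mbart")

# vocabulary size after trimming, and total parameters left in the model
print("vocab size:", len(tokenizer))
print("parameters:", sum(p.numel() for p in model.parameters()))

Multiplying the parameter count by 4 bytes gives a rough fp32 on-disk size to compare against the 387M figure above.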

BakingBrains commented 1 year ago

@IamAdiSri Thanks a lot for clarifying. Yeah, I will try different MBART variants. I was planning to use it as a decoder with input from a vision encoder.

This is the decoder. One more problem: I want to add special tokens, but if I add them and try to load the pruned model, it throws a dimension error, whereas models from hyunwoongko's repository load fine.

Here is the code that throws the error.

from transformers import MBartConfig, MBartForCausalLM, XLMRobertaTokenizer, AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import os
import re
from typing import Any, List, Optional, Union
from transformers.file_utils import ModelOutput

class BARTDecoder(nn.Module):

    def __init__(
        self, decoder_layer: int, max_position_embeddings: int, name_or_path: Union[str, bytes, os.PathLike] = None
    ):
        super().__init__()
        self.decoder_layer = decoder_layer
        self.max_position_embeddings = max_position_embeddings

        self.tokenizer = XLMRobertaTokenizer.from_pretrained(
            "prunedMbart" if not name_or_path else name_or_path
        )
        # bad_words_ids = [self.tokenizer(bad_word).input_ids for bad_word in ["entire", "save"]]
        bad_words_ids = self.tokenizer(["entire", "save"], add_special_tokens=False).input_ids

        self.model = MBartForCausalLM(
            config=MBartConfig(
                is_decoder=True,
                is_encoder_decoder=False,
                add_cross_attention=True,
                decoder_layers=self.decoder_layer,
                max_position_embeddings=self.max_position_embeddings,
                vocab_size=len(self.tokenizer),
                scale_embedding=True,
                add_final_layer_norm=True,
            )
        )
        self.model.forward = self.forward  #  to get cross attentions and utilize `generate` function

        self.model.config.is_encoder_decoder = True  # to get cross-attention
        self.add_special_tokens(["<sep/>"]) # <sep/> is used for representing a list in a JSON
        # self.add_special_tokens(["<junk/>"])
        self.model.model.decoder.embed_tokens.padding_idx = self.tokenizer.pad_token_id
        self.model.prepare_inputs_for_generation = self.prepare_inputs_for_inference

        # del self.tokenizer.vocab['▁man']

        if not name_or_path:
            bart_state_dict = MBartForCausalLM.from_pretrained("prunedMbart").state_dict()
            new_bart_state_dict = self.model.state_dict()
            for x in new_bart_state_dict:
                if x.endswith("embed_positions.weight") and self.max_position_embeddings != 1024:
                    new_bart_state_dict[x] = torch.nn.Parameter(
                        self.resize_bart_abs_pos_emb(
                            bart_state_dict[x],
                            self.max_position_embeddings
                            + 2,  # https://github.com/huggingface/transformers/blob/v4.11.3/src/transformers/models/mbart/modeling_mbart.py#L118-L119
                        )
                    )
                elif x.endswith("embed_tokens.weight") or x.endswith("lm_head.weight"):
                    new_bart_state_dict[x] = bart_state_dict[x][: len(self.tokenizer), :]
                else:
                    new_bart_state_dict[x] = bart_state_dict[x]
            self.model.load_state_dict(new_bart_state_dict)

    def add_special_tokens(self, list_of_tokens: List[str]):

        newly_added_num = self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set(list_of_tokens))})
        if newly_added_num > 0:
            self.model.resize_token_embeddings(len(self.tokenizer))

    def prepare_inputs_for_inference(self, input_ids: torch.Tensor, encoder_outputs: torch.Tensor, past=None, use_cache: bool = None, attention_mask: torch.Tensor = None):

        attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
        if past is not None:
            input_ids = input_ids[:, -1:]
        output = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past,
            "use_cache": use_cache,
            "encoder_hidden_states": encoder_outputs.last_hidden_state,
        }
        return output

    def forward(
        self,
        input_ids,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        past_key_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: bool = None,
        output_attentions: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[torch.Tensor] = None,
        return_dict: bool = None,
    ):

        output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict
        outputs = self.model.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.model.lm_head(outputs[0])

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(logits.view(-1, self.model.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return ModelOutput(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            decoder_attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    @staticmethod
    def resize_bart_abs_pos_emb(weight: torch.Tensor, max_length: int) -> torch.Tensor:
        """
        Resize position embeddings
        Truncate if sequence length of Bart backbone is greater than given max_length,
        else interpolate to max_length
        """
        if weight.shape[0] > max_length:
            weight = weight[:max_length, ...]
        else:
            weight = (
                F.interpolate(
                    weight.permute(1, 0).unsqueeze(0),
                    size=max_length,
                    mode="linear",
                    align_corners=False,
                )
                .squeeze(0)
                .permute(1, 0)
            )
        return weight

if __name__ == "__main__":
    maxlength = 1536
    decoderlayer = 4
    nameorpath = ''

    decoder = BARTDecoder(max_position_embeddings=maxlength,
                          decoder_layer=decoderlayer,
                          name_or_path=nameorpath)

If I change the model to hyunwoongko/asian-bart-ecjk, it loads fine with the special token.

Any suggestions here?

Thanks and regards.

IamAdiSri commented 1 year ago

Can you post the error as well?

IamAdiSri commented 1 year ago

Can you also post the exact code you're using to trim the model?

BakingBrains commented 1 year ago

@IamAdiSri Thanks for following up. Here is the link to my Colab: https://colab.research.google.com/drive/1NknxhhATIW7Y_-jjzzHJCzYYJFDQDGfC?usp=share_link. I tried multiple combinations; please do check.

The error is:

RuntimeError: Error(s) in loading state_dict for MBartForCausalLM:
    size mismatch for model.decoder.embed_tokens.weight: copying a param with shape torch.Size([101, 1024]) from checkpoint, the shape in current model is torch.Size([102, 1024]).
    size mismatch for lm_head.weight: copying a param with shape torch.Size([101, 1024]) from checkpoint, the shape in current model is torch.Size([102, 1024]).

IamAdiSri commented 1 year ago

I wasn't able to go through all of the code, but I've changed certain parts and rewritten it below. The decoder is working for me and producing output, so the problem most probably lies in one of your custom functions.

I saw that you're using the XLMRoberta tokenizer, which I don't think is compatible with MBart, so I switched that to MBart50Tokenizer. Additionally, to reduce the risk of errors later on, I load the model as MBartForCausalLM and then prune it (instead of pruning MBartForConditionalGeneration), since you're using the causal LM model later.

You can run the cells below in this order. I tested it on Colab.

from transformers import MBartConfig, MBart50Tokenizer, MBartForCausalLM
from hftrim.TokenizerTrimmer import TokenizerTrimmer
from hftrim.ModelTrimmers import MBartTrimmer

data = [
        "भाषाओं में उपलब्ध विकिपीडिया का सबसे बड़ा संस्करण है और सभी संस्करणों में पचपनवाँ है। और इसे मुख्यतः",
        "हिन्दीभाषी लोगों की आवश्यकताओं की पूर्ति के लिए बनाया गया था। चूँकि हिन्दी विकिपीडिया इण्डिक स्क्रिप्ट (देवनागरी)",
        "का प्रयोग करता है इसलिए इसमें जटिल पाठ प्रतिपादन सहायक की आवश्यकता पड़ती है। विकिपीडिया पर ध्वन्यात्मक रोमन",
        "वर्णमाला परिवर्तक उपलब्ध है, इसलिए बिना किसी विशेष हिन्दी टाइपिंग सॉफ्टवेर डाउनलोड किये रोमन कुंजीपटल का उपयोग",
        "देवनागरी में टंकण करने के लिए किया जा सकता है।",
]

# load pretrained config, tokenizer and model
config = MBartConfig.from_pretrained("facebook/mbart-large-50")
tokenizer = MBart50Tokenizer.from_pretrained("facebook/mbart-large-50")
model = MBartForCausalLM.from_pretrained("facebook/mbart-large-50")

# trim tokenizer
trimTokenizer = TokenizerTrimmer(tokenizer)
trimTokenizer.make_vocab(data)
trimTokenizer.make_tokenizer()

# trim model
trimModel = MBartTrimmer(model, config, trimTokenizer.trimmed_tokenizer)
trimModel.make_weights(trimTokenizer.trimmed_vocab_ids)
trimModel.make_model()

# save with
trimTokenizer.trimmed_tokenizer.save_pretrained('trimbart')
trimModel.trimmed_model.save_pretrained('trimbart')

# restart runtime before running this cell
from transformers import MBartConfig, MBart50Tokenizer, MBartForCausalLM

config = MBartConfig.from_pretrained("trimbart")
tokenizer = MBart50Tokenizer.from_pretrained("trimbart")
model = MBartForCausalLM.from_pretrained("trimbart")
# add your <sep/> token to the tokenizer then add an embedding for it to the model
tokenizer.add_special_tokens({'additional_special_tokens': tokenizer.special_tokens_map['additional_special_tokens']+['<sep/>']})
model.resize_token_embeddings(len(tokenizer))
# isolate the decoder from the model
decoder = model.model.decoder
# mbart requires setting a source and target language
tokenizer.src_lang = 'hi_IN'
tokenizer.tgt_lang = 'hi_IN'

inp = tokenizer("भाषाओं में उपलब्ध विकिपीडिया का सबसे बड़ा संस्करण है और सभी संस्करणों में पचपनवाँ है। और इसे मुख्यतः <sep/>", return_tensors='pt')
out = decoder(**inp)
print(out)

BakingBrains commented 1 year ago

@IamAdiSri Thanks a lot. I did the same thing, but I was going wrong in the part where I add the special token to the tokenizer.

Thank you.😄

IamAdiSri commented 1 year ago

If you get this working I'd love to see the results. Lmk how it goes, or if you run into other problems.