kaistAI / LangBridge

[ACL 2024] LangBridge: Multilingual Reasoning Without Multilingual Supervision
https://aclanthology.org/2024.acl-long.405/
71 stars 7 forks source link

LangBridge on "facebook/nllb-200-distilled-600M"(other than T5 architecture) #17

Open Kosei1227 opened 3 days ago

Kosei1227 commented 3 days ago

Hi, based on my understanding, the LangBridge approach can be extended to seq2seq models that have a dedicated {model_name}EncoderModel class in HuggingFace. But what about seq2seq models that only provide the full encoder-decoder model class, such as m2m100 and nllb-200?

I implemented the NLLBModeling from scratch.

import math
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.modeling_outputs import BaseModelOutput

logger = logging.get_logger(__name__)

# Configuration class for NLLB
class NLLBConfig(PretrainedConfig):
    """Configuration for the from-scratch NLLB encoder defined in this file.

    The hyper-parameters are stored under the names used by this class's
    signature (``num_encoder_layers``, ``num_heads``, ``d_ff``,
    ``dropout_rate``, ``attention_dropout_rate``) and additionally under the
    names the modules in this file actually read (``encoder_layers``,
    ``encoder_attention_heads``, ``encoder_ffn_dim``, ``dropout``,
    ``attention_dropout``). Without the second set a default-constructed
    config raises ``AttributeError`` as soon as a module is built from it.

    A value passed explicitly under one of the second-set names (for example
    through ``**kwargs``) takes precedence over the fallback derived from the
    corresponding first-set argument.
    """

    model_type = "nllb"

    def __init__(
        self,
        vocab_size=256204,  # Vocabulary size of NLLB
        d_model=1024,       # Dimension of the embeddings and hidden states
        num_encoder_layers=24,  # Number of encoder layers
        num_decoder_layers=24,  # Number of decoder layers
        num_heads=16,       # Number of attention heads
        d_ff=4096,          # Dimension of the feedforward network
        dropout_rate=0.1,   # Dropout rate
        attention_dropout_rate=0.1,  # Dropout rate for attention weights
        activation_function="relu",  # Activation function
        max_position_embeddings=1024,  # Maximum sequence length
        layer_norm_eps=1e-6,  # Epsilon for layer normalization
        is_encoder_decoder=False,  # Indicates if it's an encoder-decoder model
        use_cache=False,      # Whether to use cache during generation
        **kwargs
    ):
        # Resolve the names consumed by the encoder modules before handing the
        # remaining kwargs to the base class: an explicit value wins, otherwise
        # fall back to the equivalent argument of this signature.
        encoder_layers = kwargs.pop("encoder_layers", num_encoder_layers)
        encoder_attention_heads = kwargs.pop("encoder_attention_heads", num_heads)
        encoder_ffn_dim = kwargs.pop("encoder_ffn_dim", d_ff)
        dropout = kwargs.pop("dropout", dropout_rate)
        attention_dropout = kwargs.pop("attention_dropout", attention_dropout_rate)

        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.dropout_rate = dropout_rate
        self.attention_dropout_rate = attention_dropout_rate
        self.activation_function = activation_function
        self.max_position_embeddings = max_position_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.is_encoder_decoder = is_encoder_decoder
        self.use_cache = use_cache

        # Names read by NLLBFeedForward, NLLBAttention, NLLBEncoderLayer and
        # NLLBEncoder.
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.encoder_ffn_dim = encoder_ffn_dim
        self.dropout = dropout
        self.attention_dropout = attention_dropout

# Layer Normalization class
class NLLBLayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable scale.

    Only a multiplicative ``weight`` is learned; there is no additive bias.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x):
        """Return ``x`` standardized along its last axis and scaled by ``weight``."""
        centered = x - x.mean(-1, keepdim=True)
        # Population (biased) variance of the last axis.
        variance = centered.pow(2).mean(-1, keepdim=True)
        return self.weight * (centered / torch.sqrt(variance + self.eps))

# Feedforward network
class NLLBFeedForward(nn.Module):
    """Position-wise feed-forward block: Linear -> activation -> dropout -> Linear.

    Reads ``d_model``, ``encoder_ffn_dim``, ``activation_function`` and
    ``dropout`` from ``config``. Any ``activation_function`` other than
    ``"relu"`` selects GELU.
    """

    def __init__(self, config):
        super().__init__()
        model_dim = config.d_model
        inner_dim = config.encoder_ffn_dim

        self.dense_1 = nn.Linear(model_dim, inner_dim)
        if config.activation_function == "relu":
            self.activation = nn.ReLU()
        else:
            self.activation = nn.GELU()
        self.dense_2 = nn.Linear(inner_dim, model_dim)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply the feed-forward transformation to ``x``."""
        expanded = self.activation(self.dense_1(x))
        return self.dense_2(self.dropout(expanded))

class NLLBAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Reads ``d_model``, ``encoder_attention_heads`` and ``attention_dropout``
    from ``config``. Queries, keys and values are all projected from the same
    ``hidden_states`` tensor.
    """

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.d_model
        self.num_heads = config.encoder_attention_heads
        self.dropout = config.attention_dropout
        self.is_decoder = False  # this module is only used as encoder self-attention

        self.head_dim, leftover = divmod(self.embed_dim, self.num_heads)
        assert leftover == 0, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        # Projections are created in the order k, v, q, out.
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

        self.dropout_layer = nn.Dropout(self.dropout)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: tensor of size (batch, seq, embed_dim).
            attention_mask: optional additive mask, added to the raw attention
                scores before the softmax.
            output_attentions: when truthy, the post-dropout attention weights
                are appended to the returned tuple.

        Returns:
            ``(attn_output,)`` or ``(attn_output, attn_weights)``.
        """
        batch_size, seq_length, _ = hidden_states.size()

        def split_heads(projected):
            # (batch, seq, embed_dim) -> (batch, heads, seq, head_dim)
            per_head = projected.view(batch_size, seq_length, self.num_heads, self.head_dim)
            return per_head.transpose(1, 2)

        query = split_heads(self.q_proj(hidden_states)) * self.scaling
        key = split_heads(self.k_proj(hidden_states))
        value = split_heads(self.v_proj(hidden_states))

        scores = torch.matmul(query, key.transpose(-2, -1))
        if attention_mask is not None:
            scores = scores + attention_mask

        probs = self.dropout_layer(nn.functional.softmax(scores, dim=-1))

        # Merge the heads back: (batch, heads, seq, head_dim) -> (batch, seq, embed_dim)
        context = torch.matmul(probs, value).transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_length, self.embed_dim)
        projected = self.out_proj(context)

        if output_attentions:
            return (projected, probs)
        return (projected,)

# Transformer Encoder Layer
class NLLBEncoderLayer(nn.Module):
    """One pre-norm encoder layer: self-attention then feed-forward.

    Each sub-block normalizes its input, applies the transformation and
    dropout, and adds the result back onto the un-normalized input.
    """

    def __init__(self, config):
        super().__init__()
        # Falls back to 1e-5 when the config does not define layer_norm_eps.
        norm_eps = getattr(config, 'layer_norm_eps', 1e-5)
        model_dim = config.d_model
        drop_prob = config.dropout

        # Self-attention sub-block
        self.self_attn = NLLBAttention(config)
        self.norm1 = NLLBLayerNorm(model_dim, eps=norm_eps)
        self.dropout1 = nn.Dropout(drop_prob)

        # Feed-forward sub-block
        self.feed_forward = NLLBFeedForward(config)
        self.norm2 = NLLBLayerNorm(model_dim, eps=norm_eps)
        self.dropout2 = nn.Dropout(drop_prob)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        """Apply the layer to ``hidden_states``.

        Returns:
            ``(hidden_states,)`` or, when ``output_attentions`` is truthy,
            ``(hidden_states, attn_weights)``.
        """
        # Self-attention with residual connection
        attn_outputs = self.self_attn(
            self.norm1(hidden_states),
            attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + self.dropout1(attn_outputs[0])

        # Feed-forward with residual connection
        ffn_output = self.feed_forward(self.norm2(hidden_states))
        hidden_states = hidden_states + self.dropout2(ffn_output)

        if output_attentions:
            return (hidden_states, attn_outputs[1])
        return (hidden_states,)

# Positional Embeddings
class NLLBPositionalEmbedding(nn.Module):
    """Learned position embedding table, looked up by integer position ids."""

    def __init__(self, max_position_embeddings, embedding_dim):
        super().__init__()
        # The table is stored under the attribute name ``weight``.
        self.weight = nn.Embedding(
            num_embeddings=max_position_embeddings,
            embedding_dim=embedding_dim,
        )

    def forward(self, position_ids):
        """Return the embedding vectors for ``position_ids``."""
        position_embeddings = self.weight(position_ids)
        return position_embeddings

# Transformer Encoder
class NLLBEncoder(nn.Module):
    """Encoder stack over token embeddings plus learned position embeddings.

    The forward pass sums token and position embeddings, applies
    ``layernorm_embedding`` and dropout, runs every ``NLLBEncoderLayer`` in
    order, and finishes with a final ``nn.LayerNorm``.
    """

    def __init__(self, config):
        super().__init__()
        self.dropout = nn.Dropout(config.dropout)

        self.layernorm_embedding = NLLBLayerNorm(config.d_model, eps=1e-6)

        self.config = config  # kept for vocab_size lookups in forward()
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, padding_idx=config.pad_token_id)
        self.embed_positions = nn.Embedding(config.max_position_embeddings, config.d_model)
        self.layers = nn.ModuleList([NLLBEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layer_norm = nn.LayerNorm(config.d_model)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    ):
        """Encode ``input_ids``.

        Args:
            input_ids: integer tensor of token ids, indexed as (batch, seq).
            attention_mask: optional (batch, seq) mask with 1 for positions to
                attend to and 0 for positions to ignore; defaults to all ones.
            output_attentions: collect the attention weights of every layer.
            output_hidden_states: collect the input of every layer plus the
                final normalized output.
            return_dict: return a ``BaseModelOutput`` instead of a tuple.

        Returns:
            A ``BaseModelOutput`` when ``return_dict`` is truthy, otherwise a
            tuple of the non-``None`` values among
            ``(last_hidden_state, all_hidden_states, all_attentions)``.

        Raises:
            ValueError: if ``input_ids`` is ``None``.
        """
        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        # Out-of-range ids are still clamped (as before), but the condition is
        # now detected *before* clamping so it is reported instead of being
        # silently hidden; the previous post-clamp check could never trigger.
        max_token_id = self.config.vocab_size - 1
        if ((input_ids < 0) | (input_ids > max_token_id)).any():
            logger.warning(
                "input_ids contain values outside the valid range [0, %d] "
                "(found min %d, max %d); clamping them into range. Check that "
                "the tokenizer matches config.vocab_size.",
                max_token_id,
                input_ids.min().item(),
                input_ids.max().item(),
            )
            input_ids = torch.clamp(input_ids, min=0, max=max_token_id)

        input_shape = input_ids.size()
        device = input_ids.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)

        # Additive mask broadcastable over heads and query positions.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # Embeddings: token + position, then norm and dropout.
        inputs_embeds = self.embed_tokens(input_ids)
        seq_length = input_shape[1]
        position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0).expand(input_shape)

        position_embeds = self.embed_positions(position_ids)
        hidden_states = inputs_embeds + position_embeds
        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Encoder layers
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for layer_module in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                attention_mask=extended_attention_mask,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )

    def get_extended_attention_mask(self, attention_mask, input_shape, device):
        """Turn a 2D 0/1 mask into a 4D additive mask.

        The result is indexed as (batch, 1, 1, seq) and has the dtype of the
        token embedding weights: 0 where ``attention_mask`` is 1, and a large
        negative value where it is 0. ``input_shape`` and ``device`` are
        accepted for interface compatibility and are not used.
        """
        mask_dtype = self.embed_tokens.weight.dtype
        extended_attention_mask = attention_mask[:, None, None, :].to(dtype=mask_dtype)

        # -1e9 is not representable in float16 and would become -inf, so the
        # fill value is bounded by the most negative finite value of the dtype.
        # For float32 this is still exactly -1e9.
        mask_value = max(-1e9, torch.finfo(mask_dtype).min)
        return (1.0 - extended_attention_mask) * mask_value

# Pretrained Model class
class NLLBPreTrainedModel(PreTrainedModel):
    """Base class binding ``NLLBConfig`` to the models in this file and
    defining how their weights are initialized."""

    config_class = NLLBConfig
    base_model_prefix = "nllb_model"
    # NOTE(review): no gradient-checkpointing logic is visible in the encoder
    # classes of this file -- confirm this flag is intended.
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights of a single sub-module.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from a normal
        distribution with std ``d_model ** -0.5``; linear biases are zeroed;
        ``NLLBLayerNorm`` scales are reset to one. For an embedding with a
        ``padding_idx`` the padding row is zeroed again afterwards, because
        the normal draw would otherwise overwrite the all-zero padding vector
        that ``nn.Embedding`` sets up.
        """
        std = self.config.d_model ** -0.5
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, NLLBLayerNorm):
            module.weight.data.fill_(1.0)

# NLLB Encoder Model
class NLLBEncoderModel(NLLBPreTrainedModel):
    """Encoder-only model: a thin wrapper exposing ``NLLBEncoder`` through the
    ``NLLBPreTrainedModel`` interface."""

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.encoder = NLLBEncoder(config)

        # Hand over to the base class's post-construction hook.
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the encoder on ``input_ids``.

        ``return_dict`` falls back to ``config.use_return_dict`` when it is
        ``None``. Returns a ``BaseModelOutput`` when ``return_dict`` is
        truthy, otherwise whatever tuple the encoder produced.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoded = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=encoded.last_hidden_state,
                hidden_states=encoded.hidden_states,
                attentions=encoded.attentions,
            )
        return encoded

    def get_input_embeddings(self):
        """Return the encoder's token embedding module."""
        return self.encoder.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the encoder's token embedding module with ``value``."""
        self.encoder.embed_tokens = value

    def _tie_weights(self):
        # Intentionally a no-op: this model ties no weights.
        pass

The results after LangBridge on NLLB-200 were terrible. It doesn't even output readable sentences. Question: A basket contains 25 oranges among which 1 is bad, 20% are unripe, 2 are sour and the rest are good. How many oranges are good? Response: paa family צריך ukatsti family צריךတို့သည်ikisaant family צריךတို့သည်တို့သည်တို့သည်яў the family צריךတို့သည်တို့သည်တို့သည် т family צריךတို့သည်يىن צריךတို့သည်တို့သည် su family צריךတို့သည်တို့သည်яў ukatsti d valor ukatsti family צריךတို့သည်ikisaant family צריךတို့သည်တို့သည်яў the family צריךတို့သည်တို့သည် т family צריךတို့သည်يىن צריךတို့သည်တို့သည် su family צריךတို့သည်яў ukatsti dण्ड ukatsti family צריךတို့သည်ikisaant family צריךတို့သည်яў the family צריךတို့သည် т family צריךတို့သည်يىن צריךတို့သည်တို့သည် su family צריך tout ukatsti d 지원 ukatsti family צריךတို့သည်ikisaant family צריך tout the family צריך т family צריךတို့သည်يىن צריךတို့သည်တို့သည် su familyတို့သည်яў ukatsti d dद्धन् hokன்00 family צריךတို့သည်တို့သည်တို့သည် А family צריךတို့သည်တို့သည် А family צריךတို့သည် А familyတို့သည် su family צריך צריך צריךတို့သည်яў ئ kol ukatsti dरे family צריך צריך צריךတို့သည် dக்க Fran thepaa family צריך צריך צריךတို့သည်يىن צריך צריך צריךတို့သည်يىن צריך צריך צריךတို့သည် family צריך צריך צריךတို့သည်يىن צריך צריך צריךတို့သည် family צריך צריך צריךတို့သည်يىن צריך צריך צריךတို့သည် family צריך צריך צריךတို့သည်يىن צריך צריך צריךတို့သည် family צריך צריך צריךတို့သည်يىن צריך צריך צריךတို့သည် family צריך צריך צריךတို့သည်يىن צריך צריך צריךတို့သည် family צריך צריך צריךတို့သည်يىن צריך צריך צריךတို့သည် family צריך צריך צריךတို့သည်يىن צריך צריך צריךတို့သည် family צריך צריך צריךတို့သည်يىن צריך צריך צריךတို့သည် family צריך צריך צריךတို့သည်يىن צריך צריך צריךတို့သည် family צריך צריך צריךတို့သည်يىن צריך צריך צריךတို့သည် family צריך צריך צריךတို့သည်يىن צריך צריך צריךတို့သည် family צריך צריך צריךတို့သည်يىن צריך צריך צריךတို့သည် family צריך צריך צריךတို့သည်يىن צריך צריך צריךတို့သည် family צריך צריך צריךတို့သည်يىن צריך צריך צריךတို့သည် family

Have the authors ever tried to implement LangBridge on other models? Also, I would like to know whether LangBridge can be applied with a short wrapper like the following code.

from transformers import M2M100Config, M2M100Model
from torch import nn

class NLLBEncoderModel(nn.Module):
    """Encoder-only wrapper around ``M2M100Model``.

    ``forward`` runs only the encoder obtained from the wrapped model via
    ``get_encoder()``; the input-embedding accessors are delegated to the
    wrapped model.
    """

    def __init__(self, config):
        super().__init__()
        self.model = M2M100Model(config)
        self.encoder = self.model.get_encoder()  # Extract the encoder from M2M100 (similar to NLLB-200)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Run the encoder; extra keyword arguments are passed through to it."""
        return self.encoder(input_ids, attention_mask=attention_mask, **kwargs)

    def get_input_embeddings(self):
        """Return the input embeddings of the wrapped model."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        """Set the input embeddings of the wrapped model."""
        self.model.set_input_embeddings(new_embeddings)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Build the wrapper around pretrained ``M2M100Model`` weights.

        Positional and keyword arguments are forwarded to
        ``M2M100Model.from_pretrained``.
        """
        config = M2M100Config.from_pretrained(pretrained_model_name_or_path)
        model = cls(config)
        model.model = M2M100Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        # __init__ bound ``encoder`` to the encoder of the randomly initialized
        # model it created. Rebind it to the pretrained model's encoder,
        # otherwise forward() would keep running on the random weights.
        model.encoder = model.model.get_encoder()
        return model
MattYoon commented 2 days ago

I haven't tried translation models for the Encoder, but someone else has. Please check https://github.com/kaistAI/LangBridge/issues/9 . They also found that NLLB doesn't perform well.

Also do note that smaller Encoder models will not perform well. Check Appendix D.2.

Kosei1227 commented 2 days ago

Hi! Thank you so much for your informative reply. NLLB is a great encoder, but I doubt whether NLLB is capable of producing soft prompts. Do you have any theoretical or experimental ideas about why NLLB won't work in LangBridge?

Thank you

MattYoon commented 2 days ago

I don't have a clear answer for that.

My speculation is that the output representation of NLLB encoder might not be as language agnostic as mT5. mT5 was trained with completely unlabeled multilingual data, so the output representation is naturally language agnostic.

I'm not sure if the same holds for NLLB, since when NLLB was trained, you would explicitly tell the model what the input language is with language tokens. I'm quite sure that will deter the encoder from forming a language agnostic feature at the output.

But again, this is just my speculation and I don't have clear evidence of it.