pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

[RFC] Adding RoPE scaling methods to support long context modeling #1183

Open joecummings opened 1 month ago

joecummings commented 1 month ago

Background

For large-document understanding or tasks like code completion, it's often beneficial to have a large context length, e.g. > 8K. For this to be enabled by default, a model would have to be pretrained on large amounts of data with a long context length, which is costly in time and therefore money. However, Chen et al. found that you can take advantage of the relative nature of RoPE and fine-tune a shorter-context model to be "long-context" with some changes to the embedding calculation.

There are several popular methods for scaling RoPE:

  1. Position interpolation (PI): The original RoPE scaling method, which divides the theta/freq matrix of RoPE by s = new context length / old context length (equivalently, it squeezes the position indices by a factor of s).
  2. NTK-aware: Instead of scaling every dimension of theta/freqs by the same factor of s, scale high frequencies less and low frequencies more.
  3. NTK-by-parts: Use a piecewise linear function to smooth out the transition between original values and interpolated values.
  4. YaRN: NTK-by-parts + attention scaling.

Thanks to the EleutherAI blog post for the very clear explanation of the above concepts.
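As a rough illustration of the first two methods, the sketch below computes the RoPE frequencies in plain Python, with lists standing in for tensors so it runs without torch. The function names are hypothetical, and the NTK-aware variant uses one common formulation (enlarging the base by s^(dim / (dim - 2))); treat both as assumptions rather than torchtune API.

```python
def rope_theta(dim: int, base: float = 10_000.0) -> list[float]:
    """Unscaled RoPE frequencies: theta_i = base ** (-2i / dim)."""
    return [base ** (-2 * i / dim) for i in range(dim // 2)]


def position_interpolation(theta: list[float], s: float) -> list[float]:
    """PI: divide every frequency by s (same as squeezing positions by s)."""
    return [t / s for t in theta]


def ntk_aware(dim: int, s: float, base: float = 10_000.0) -> list[float]:
    """NTK-aware (one common form): grow the base so that the highest
    frequency is untouched and the lowest frequency is divided by s."""
    new_base = base * s ** (dim / (dim - 2))
    return rope_theta(dim, new_base)


dim, s = 8, 4.0  # e.g. extending a 4K model to 16K
theta = rope_theta(dim)
pi_theta = position_interpolation(theta, s)
ntk_theta = ntk_aware(dim, s)
```

With PI every dimension shrinks by the same factor, while the NTK-aware variant leaves the first (highest-frequency) dimension alone and only the last dimension ends up fully divided by s.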

There is evidence of some community appetite for this support: #1120

Proposal

All of the current methods for RoPE scaling really only modify the theta/freq matrix. Therefore, we can support scaling with minimal code changes.

class YaRN:
    def __init__(self, alpha: int, beta: int, scaling_factor: float, ...) -> None:
        ...

    def __call__(self, theta: torch.Tensor) -> torch.Tensor:
        # Scale theta according to YaRN algo
        ...
        return new_theta

class RotaryPositionalEmbeddings(nn.Module):
    def __init__(
        self,
        dim: int,
        max_seq_len: int = 4096,
        base: int = 10_000,
        scaling_module: Optional[Callable] = None,
        scaling_factor_for_attn: float = 1.0,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.base = base
        self.max_seq_len = max_seq_len
        self.scaling_module = scaling_module
        self.scaling_factor_for_attn = scaling_factor_for_attn
        self._rope_init()

    def reset_parameters(self):
        self._rope_init()

    def _rope_init(self):
        theta = 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim)
        )
        self.register_buffer("theta", theta, persistent=False)
        self.build_rope_cache(self.max_seq_len)

    def build_rope_cache(self, max_seq_len: int = 4096) -> None:
        if self.scaling_module is not None:
            # `theta` is a registered buffer, so it must be read via `self`
            self.theta = self.scaling_module(self.theta)

        # Create position indexes `[0, 1, ..., max_seq_len - 1]`
        seq_idx = torch.arange(
            max_seq_len, dtype=self.theta.dtype, device=self.theta.device
        )

        # Outer product of theta and position index; output tensor has
        # a shape of [max_seq_len, dim // 2]
        idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float()

        # cache includes both the cos and sin components and so the output shape is
        # [max_seq_len, dim // 2, 2]
        cache = torch.stack(
            [
                torch.cos(idx_theta) * self.scaling_factor_for_attn,
                torch.sin(idx_theta) * self.scaling_factor_for_attn,
            ],
            dim=-1,
        )
        self.register_buffer("cache", cache, persistent=False)

    def forward(self, x: Tensor, *, input_pos: Optional[Tensor] = None) -> Tensor:
        # Nothing changes in the forward pass
        ...

This would immediately let us support pre-trained models that have longer context lengths, like Phi-3 Mini 128K.
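To make the plug-in point concrete, here is a dependency-free sketch of how a scaling callable could slot into cache construction. `LinearScaling` and this standalone `build_rope_cache` are hypothetical stand-ins (plain lists instead of tensors), not the proposed torchtune classes.

```python
import math
from typing import Callable, Optional

Theta = list[float]


class LinearScaling:
    """Hypothetical scaling module implementing position interpolation."""

    def __init__(self, scaling_factor: float) -> None:
        self.scaling_factor = scaling_factor

    def __call__(self, theta: Theta) -> Theta:
        return [t / self.scaling_factor for t in theta]


def build_rope_cache(
    theta: Theta,
    max_seq_len: int,
    scaling_module: Optional[Callable[[Theta], Theta]] = None,
    scaling_factor_for_attn: float = 1.0,
) -> list[list[tuple[float, float]]]:
    """Return cache[pos][i] = (cos, sin) of pos * theta_i."""
    if scaling_module is not None:
        theta = scaling_module(theta)
    return [
        [
            (
                math.cos(pos * t) * scaling_factor_for_attn,
                math.sin(pos * t) * scaling_factor_for_attn,
            )
            for t in theta
        ]
        for pos in range(max_seq_len)
    ]


theta = [10_000.0 ** (-2 * i / 8) for i in range(4)]
plain = build_rope_cache(theta, max_seq_len=8)
scaled = build_rope_cache(theta, max_seq_len=16, scaling_module=LinearScaling(2.0))
```

With a scaling factor of 2, position 2k in the scaled cache carries the same rotation as position k in the unscaled one, which is the "interpolation" in position interpolation.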

One of the main reasons this is slightly preferred over the alternatives is that it's becoming clear that model families are implementing "YaRN-inspired" or "NTK-inspired" scaling methods (Phi3). This means in order to enable a 1:1 implementation of these models, we need to implement the model-specific version of RoPE scaling, rather than a canonical YaRN, etc. So, instead of creating an entirely new ModelXRotaryPositionalEmbedding, which would proliferate our somewhat slow RoPE throughout the repository and add an additional larger component to test individually, we create a model-specific ModelXScalingModule that we can pass into the RoPE embeddings.


class Phi3ScalingModule:
    def __init__(self, ...):
        ...

    def __call__(self, theta: torch.Tensor) -> torch.Tensor:
        # Modify theta according to Phi3Scaling
        ...

def phi3_mini(num_heads, max_seq_len, ...) -> TransformerDecoder:
    if max_seq_len == 128K:
        rope = RotaryPositionalEmbeddings(
            dim=dim,
            ...,
            scaling_module=Phi3ScalingModule(),
        )
    ...

Alternatives

The above method assumes that scaling of RoPE will only modify the underlying theta/freq matrix. If we want to be less prescriptive in our implementation, we could do a whole new YaRNScaledRotaryPositionalEmbeddings:


class YaRNScaledRotaryPositionalEmbeddings(nn.Module):
    def __init__(
        self,
        dim,
        alpha,
        beta,
        base,
        scaling_factor,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.alpha = alpha
        ...
        self.build_scaled_rope()

    def build_scaled_rope(self):
        # 1. Construct theta
        # 2. Scale theta according to YaRN algo
        ...  

    def forward(self, x: Tensor, *, input_pos: Optional[Tensor] = None) -> Tensor:
        # Same forward pass as RotaryPositionalEmbeddings
        ...

We would then have one of these XScaledRotaryPositionalEmbeddings for every scaling method out there.

FAQs

1. How would this extend to models that aren't long-context by default?

This is an interesting question b/c technically, once we have YaRN, RoPE, etc. in place, we can add it to a "normal" model like Llama3, fine-tune it with a long context dataset, and boom, you now have a Llama3-128K model. I think the best possible way to let people leverage this is to train our own model (I'll do it or something, not endorsed by Meta, please don't get legal involved) and then write up a quick blog post on how I did it so others can follow suit.

2. Do we need any additional memory/performance considerations to make this work?

There are several possible bottlenecks besides RoPE for long context. First is the time it takes to run attention calculations for a context length of up to 128K. To start, I think we make this issue known to users and consider adding sliding window attention or pre-filling to mitigate this performance hit. Related to this first issue is the varying length of data common in datasets. Unless a user is utilizing a specific long-context dataset where every sample is > 100K context length, there will be a lot of wasted padding space in batches. The mitigation here is to recommend using sample packing to train efficiently.
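To put a number on the padding concern, the sketch below compares padding every sample to `max_seq_len` against a simple first-fit packing of sample lengths. The lengths are made up for illustration, and `pack_first_fit` is a hypothetical helper; a real packed dataset also needs attention masks so that packed samples do not attend to each other.

```python
def padding_waste(lengths: list[int], max_seq_len: int) -> int:
    """Padding tokens needed when every sample is padded to max_seq_len."""
    return sum(max_seq_len - n for n in lengths)


def pack_first_fit(lengths: list[int], max_seq_len: int) -> list[list[int]]:
    """Place each sample into the first pack that still has room for it."""
    packs: list[list[int]] = []
    for n in lengths:
        if n > max_seq_len:
            raise ValueError(f"sample of length {n} exceeds max_seq_len")
        for pack in packs:
            if sum(pack) + n <= max_seq_len:
                pack.append(n)
                break
        else:
            # No existing pack had room, so start a new one.
            packs.append([n])
    return packs


lengths = [3000, 900, 5000, 1200, 7000, 400]  # illustrative token counts
max_seq_len = 8192

unpacked_waste = padding_waste(lengths, max_seq_len)
packs = pack_first_fit(lengths, max_seq_len)
packed_waste = len(packs) * max_seq_len - sum(lengths)
```

For these made-up lengths, packing needs three sequences instead of six and cuts the padding from 31,652 tokens to 7,076.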

3. How will testing work?

Similar to how we test the RoPE embeddings, we will run the canonical script for whichever scaling method we are adding, grab those numbers, and compare against them in a unit test. The somewhat annoying thing is that the majority of scaling methods are integrated into an HF-style RoPE embeddings implementation. This version of RoPE requires a permutation of the QKV values to obtain the same end result. Therefore, we can't easily do a 1:1 mapping and instead have to test that the end results of the attention calculation are the same OR apply a permutation as well in our test OR unpermute the values used in that RoPE calculation. We could consider adopting the HF-style RoPE embeddings b/c it is faster, but that is out of scope for this work. Something to consider though is that if the HF-style RoPE embeddings are the default, any advancements to attention or RoPE will be harder to test until we have a similar implementation in torchtune.
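The permutation issue can be seen on a single vector. The sketch below assumes the two layouts differ only in how dimensions are paired: adjacent pairs (x[2i], x[2i+1]) in one, and split halves (x[i], x[i + d/2]) in the HF style. Function names are hypothetical and plain lists stand in for tensors.

```python
import math


def rope_interleaved(x: list[float], pos: int, theta: list[float]) -> list[float]:
    """Rotate adjacent pairs (x[2i], x[2i+1]) by the angle pos * theta_i."""
    out: list[float] = []
    for i, t in enumerate(theta):
        c, s = math.cos(pos * t), math.sin(pos * t)
        a, b = x[2 * i], x[2 * i + 1]
        out += [a * c - b * s, b * c + a * s]
    return out


def rope_half_split(x: list[float], pos: int, theta: list[float]) -> list[float]:
    """Rotate pairs (x[i], x[i + d/2]) by the angle pos * theta_i."""
    half = len(x) // 2
    out = [0.0] * len(x)
    for i, t in enumerate(theta):
        c, s = math.cos(pos * t), math.sin(pos * t)
        a, b = x[i], x[i + half]
        out[i] = a * c - b * s
        out[i + half] = b * c + a * s
    return out


def to_half_split_layout(x: list[float]) -> list[float]:
    """Permute [x0, x1, x2, x3, ...] into [x0, x2, ..., x1, x3, ...]."""
    return x[0::2] + x[1::2]


theta = [10_000.0 ** (-2 * i / 4) for i in range(2)]
x = [0.1, -0.2, 0.3, 0.4]

lhs = rope_half_split(to_half_split_layout(x), pos=5, theta=theta)
rhs = to_half_split_layout(rope_interleaved(x, pos=5, theta=theta))
```

Here `lhs` and `rhs` agree, whereas applying the two functions to the same unpermuted input gives different values, which is why a test has to permute (or unpermute) before comparing against reference numbers.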

4. How can we leverage community contributions?

It'll be pretty easy and valuable to define a task for contributors to add NTK-by-parts, PI, etc. to the repo so users can play around with different scaling methods.
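As a flavor of what such a contribution might contain, here is a hedged sketch of NTK-by-parts written as a scaling callable, plus the attention factor YaRN adds on top. The class name is hypothetical, plain lists stand in for tensors, and the defaults alpha=1 and beta=32 as well as the 0.1 * ln(s) + 1 factor are the values I recall from the YaRN paper for Llama-style models, so treat them as assumptions.

```python
import math


def rope_theta(dim: int, base: float = 10_000.0) -> list[float]:
    """Unscaled RoPE frequencies: theta_i = base ** (-2i / dim)."""
    return [base ** (-2 * i / dim) for i in range(dim // 2)]


class NTKByPartsScaling:
    """Hypothetical scaling module: interpolate only low-frequency dimensions."""

    def __init__(
        self,
        scaling_factor: float,
        original_max_seq_len: int,
        alpha: float = 1.0,
        beta: float = 32.0,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.original_max_seq_len = original_max_seq_len
        self.alpha = alpha
        self.beta = beta

    def __call__(self, theta: list[float]) -> list[float]:
        scaled = []
        for t in theta:
            # Full rotations this dimension completes over the original context.
            rotations = self.original_max_seq_len * t / (2 * math.pi)
            # Linear ramp clamped to [0, 1]: 0 means fully interpolated
            # (theta / s), 1 means left unchanged.
            ramp = (rotations - self.alpha) / (self.beta - self.alpha)
            gamma = min(1.0, max(0.0, ramp))
            scaled.append((1.0 - gamma) * t / self.scaling_factor + gamma * t)
        return scaled


def yarn_attention_factor(scaling_factor: float) -> float:
    """Factor applied to the cos/sin cache (scaling_factor_for_attn above)."""
    return 0.1 * math.log(scaling_factor) + 1.0


theta = rope_theta(dim=128)
scaled = NTKByPartsScaling(scaling_factor=4.0, original_max_seq_len=4096)(theta)
```

High-frequency dimensions (many rotations within the original context) pass through untouched, the lowest-frequency dimensions are divided by the scaling factor exactly as in PI, and everything in between is blended linearly.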

pbontrager commented 1 month ago

Thanks for this great writeup. I'm not an expert on positional embeddings, but the thing that makes me a bit nervous about this approach is that already there are scaling approaches that can't be contained by the scaling module abstraction (e.g. YaRN). In the examples you provide, it looks like all of the scaling work happens inside of build_rope_cache. Would it be feasible to just allow setting a custom build_rope_cache function to support all of the scaling methods in a non-prescriptive way and still avoid excessive code duplication?

kartikayk commented 1 month ago

Thanks for writing this up! As a reference, I found this to be quite an interesting read as well, especially around the experiments.

A few questions:

This means in order to enable a 1:1 implementation of these models, we need to implement the model-specific version of RoPE scaling, rather than a canonical YaRN

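So this generally makes sense to me. I wonder if we've considered a more composable approach, i.e. creating a YaRNScaledRotaryPositionalEmbeddings or model-specific RoPE class which has a RotaryPositionalEmbeddings object as a field which is created inside the init? So you create the RoPE object, scale the theta/freq matrix, build the cache and scale the cos/sin embeddings. Essentially this is a wrapper around the RoPE embedding which is what we really want? The benefits are:

  • you can define the logic for the scaling module within the class instead of a separate scaling function which is a more natural abstraction IMO
  • you can specify default values for scaling_factor_for_attn and tie this to the given implementation. In the current proposal it seems like the burden is on the user to correctly define this and pass to RoPE which makes me nervous
  • the builder function uses the ACTUAL RoPE module instead of the generic one which is more readable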

Do we need any additional memory/performance considerations to make this work?

For training on long context, do we need to update the max_seq_len of the model and make it a param? For example, let's say I want to train on 16K for a model which supports 8K sequence lengths. Does the cache need to be updated to a size of (16K, dim)? If so, this will impact memory? Or did I get this all wrong?

fine tune it with a long context dataset

Do you have examples of these? Would love to learn more about some of the datasets which have long sequences.

We could consider adopting the HF-style RoPE embeddings b/c it is faster

Yeh I think we should pull the trigger soon. @janeyx99 has been patiently waiting for us to do this. BTW what makes their implementation faster? Did we ever figure that out?

felipemello1 commented 1 month ago

Great RFC!

The focal point seems to be "All of the current methods for RoPE scaling really only modify the theta/freq matrix".

@pbontrager mentioned that "makes me a bit nervous about this approach is that already there are scaling approaches that can't be contained by the scaling module abstraction (e.g. YaRN)" -- my understanding is that YaRN does fit this design, right? I do agree with the point though.

I think that @kartikayk's proposal about modularity is interesting. For example, how would we support ALiBi in torchtune, or any new positional embedding that is not Rotary? If we were to support it, could/should we re-use the same builder we are using for Rotary?

For simplicity, given the models we currently support, I think that the proposed design would work.

Q: For testing, do we need to fine-tune on long context, or is running inference good enough?

Nit: I think that pre-filling only works for inference.

joecummings commented 1 month ago

Thanks for this great writeup. I'm not an expert on positional embeddings, but the thing that makes me a bit nervous about this approach is that already there are scaling approaches that can't be contained by the scaling module abstraction (e.g. YaRN). In the examples you provide, it looks like all of the scaling work happens inside of build_rope_cache. Would it be feasible to just allow setting a custom build_rope_cache function to support all of the scaling methods in a non-prescriptive way and still avoid excessive code duplication?

@pbontrager

Wanted to echo @felipemello1's comments that actually YaRN is covered by the scenario above; however, I understand the overall sentiment. I think passing in a function might be a little strange, but it does make me think that actually we could make it so theta is constructed outside of the RoPE and simply passed into the class as an init argument. Then, the class would handle the forward calculation. Something like this:

def build_theta_for_yarn(dim, base, alpha, beta, scaling_factor) -> torch.Tensor:
    # build theta matrix scaled according to YaRN algo
    ...

class RotaryPositionalEmbeddings(nn.Module):
    def __init__(self, dim, theta):
        super().__init__()
        self.dim = dim
        self.theta = theta

    def forward(self, x, input_pos):
        # same forward as before
        ...

def model_xyz(*args) -> TransformerDecoder:
    theta = build_theta_for_yarn(dim, base, alpha, beta, scaling_factor)
    rope = RotaryPositionalEmbeddings(
        dim=dim,
        theta=theta,
    )
    ...

This seems more similar to the original and current HF rope implementations which have separate functions for constructing the theta/freq matrix and for applying it during the attention calculation.

First, I want to confirm that the above implementation is similar to what you had in mind?

I think having the two completely separate does make it confusing for a user who's not an expert in positional embeddings and the difference between normal RoPE and scaled RoPE is not entirely clear, but I do agree it is slightly more future-proof.

joecummings commented 1 month ago

So this generally makes sense to me. I wonder if we've considered a more composable approach i.e. creating a YaRNScaledRotaryPositionalEmbeddings or model specific RoPE class which has a RotaryPositionalEmbeddings object as a field which is created inside the init? So you create the RoPE object, scale the theta/freq matrix, build the cache and scale the cos/sin embeddings. Essentially this is a wrapper around the RoPE embedding which is what we really want?

Alright, let me try to understand this by implementing it in code:

class YaRNScaledRotaryPositionalEmbeddings(nn.Module):
    def __init__(self, rope, *, dim, alpha, beta, scaling_factor, scaling_factor_for_attn):
        super().__init__()
        ...
        self._scale_rope(rope, dim, alpha, beta, scaling_factor, scaling_factor_for_attn)

    def _scale_rope(self, rope, dim, alpha, beta, scaling_factor, scaling_factor_for_attn):
        curr_theta = rope.theta
        # scale theta according to YaRN algo
        ...
        self.register_buffer("theta", new_theta, persistent=False)

    def forward(self, x, *, input_pos):
        # copy forward from rope
        ...

def model_xyz(args) -> TransformerDecoder:
    rope = RotaryPositionalEmbeddings(dim, ...)
    yarn_rope = YaRNScaledRotaryPositionalEmbeddings(
        rope=rope,
        dim=dim,
        alpha=alpha,
        ...
    )
    ...

In my mind, this is actually more confusing. We pass in the entire RoPE module, but only use it for the theta value and just copy pasta over the forward method. This will also proliferate our "slow" RoPE implementation around the repo, which I know we're keen to remove.

  • you can define the logic for the scaling module within the class instead of a separate scaling function which is a more natural abstraction IMO

I would argue this could happen in the YaRN class as defined in my original proposal, which neatly contains all the code needed for scaling theta.

  • you can specify default values for scaling_factor_for_attn and tie this to the given implementation. In the current proposal it seems like the burden is on the user to correctly define this and pass to RoPE which makes me nervous

Similar to how the burden is on our builder functions to provide the right values, we would also provide the appropriate values for a Llama3ScalingModule by default.

  • the builder function uses the ACTUAL RoPE module instead of the generic one which is more readable

I can definitely understand this thinking, but when the whole model definition is printed out, it will show the YaRN class as an added module within the RoPE embedding (although it might need to be defined as an nn.Module for this, not sure...)


An alternative I can see, since we only use the theta value at the moment, is something like this, wherein we don't pass around the RoPE module but rather just the theta:

class YaRNScaledRotaryPositionalEmbeddings(nn.Module):
    def __init__(self, original_theta, *, dim, alpha, beta, scaling_factor, scaling_factor_for_attn):
        super().__init__()
        ...
        self._scale_rope(original_theta, dim, alpha, beta, scaling_factor, scaling_factor_for_attn)

    def _scale_rope(self, original_theta, dim, alpha, beta, scaling_factor, scaling_factor_for_attn):
        # scale theta according to YaRN algo
        ...
        self.register_buffer("theta", new_theta, persistent=False)

    def forward(self, x, *, input_pos):
        # copy forward from rope
        ...

def model_xyz(args) -> TransformerDecoder:
    rope = RotaryPositionalEmbeddings(dim, ...)
    yarn_rope = YaRNScaledRotaryPositionalEmbeddings(
        original_theta=rope.theta,
        dim=dim,
        alpha=alpha,
        ...
    )
    ...

It's still some wasted overhead and code, but is a clearer delineation between what is needed for the scaling calculation.


Overall, I think I'd prefer complete code duplication in a class like YaRNScaledRotaryPositionalEmbeddings that doesn't reuse any previous RoPE components over the above options.

joecummings commented 1 month ago

Do you have examples of these? Would love to learn more about some of the datasets which have long sequences.

There are some more canonical ones like Long Alpaca, but that only goes up to 12K context length. Then there's some synthetic data that goes super long, like the work being done by Gradient AI. The majority ATM are code datasets like code alpaca 20k and arxiv research code. I expect this to change as people want more and more context for their models.

joecummings commented 1 month ago

For training on long context, do we need to update the max_seq_len of the model and make it a param? For example, let's say I want to train on 16K for a model which supports 8K sequence lengths. Does the cache need to be updated to a size of (16K, dim)? If so, this will impact memory? Or did I get this all wrong?

This is a great question. We will need to keep track of the original model length and the new model length, as that affects the scaling factor. And in this original implementation, yes, the cache will need to be larger, taking up a lot of memory. Realistically, I doubt this change, without SWA, will get us to 128K seq length without OOM. SWA should be considered a quick follow-up for this type of fine-tuning.
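For a back-of-the-envelope feel for these two quantities, the sketch below computes the scaling factor from the original and new lengths and the size of the cache buffer alone, using the [max_seq_len, dim // 2, 2] shape from the proposal. The float32 element size and the head dim of 128 are assumptions, and activations and attention memory are not counted here.

```python
def rope_scaling_factor(original_max_seq_len: int, new_max_seq_len: int) -> float:
    """s = new context length / old context length."""
    return new_max_seq_len / original_max_seq_len


def rope_cache_bytes(max_seq_len: int, head_dim: int, bytes_per_element: int = 4) -> int:
    """Size of a cache shaped [max_seq_len, head_dim // 2, 2]."""
    return max_seq_len * (head_dim // 2) * 2 * bytes_per_element


s = rope_scaling_factor(8 * 1024, 16 * 1024)

# Cache size in MiB per sequence length, assuming head_dim=128 and float32.
cache_mib = {
    n: rope_cache_bytes(n, head_dim=128) / 2**20
    for n in (8 * 1024, 16 * 1024, 128 * 1024)
}
```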

joecummings commented 1 month ago

Yeh I think we should pull the trigger soon. @janeyx99 has been patiently waiting for us to do this. BTW what makes their implementation faster? Did we ever figure that out?

They appear to do two fewer multiplications, and they work with 2D matrices instead of 3D.

joecummings commented 1 month ago

@felipemello1

Q: For testing, do we need to finetune on long context, or running inference is good enough?

We will definitely need to fine-tune a long context model to confirm that this is feasible and works for our users.

kartikayk commented 1 month ago

@joecummings sorry I wasn't clear with my proposal around composability. It's not what you've coded above. Though as I think about this, I realize the design needs to account for the fact that we'll be adding other "base variants" (e.g. FastRoPE and FasterRoPE), which then need their own scaled versions? Is that the right way to think about this? If so, then I take my proposal back, since I think explicitly passing the scaling_module might be the only way to prevent RoPE proliferation, i.e. if we have n base variants and m scaled versions then my proposal will lead to n*m classes. Not sure, I'll need to think harder about this, but since I don't have a better solution at the moment, I'm good with the proposal at the top.
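The counting argument can be sketched with toy stand-ins: with a passed-in scaling callable, n base classes and m scaling callables are enough to cover all n*m combinations. `BaseRoPE`, `FastRoPE`, and `linear_scaling` below are hypothetical placeholders that only store theta.

```python
from itertools import product
from typing import Callable, Optional

Theta = list[float]


class BaseRoPE:
    """Stand-in for a base variant; only stores the (possibly scaled) theta."""

    def __init__(
        self,
        theta: Theta,
        scaling_module: Optional[Callable[[Theta], Theta]] = None,
    ) -> None:
        self.theta = scaling_module(theta) if scaling_module else list(theta)


class FastRoPE(BaseRoPE):
    """Hypothetical second base variant with the same constructor."""


def linear_scaling(s: float) -> Callable[[Theta], Theta]:
    """Return a callable that divides every frequency by s."""
    return lambda theta: [t / s for t in theta]


base_variants = [BaseRoPE, FastRoPE]  # n = 2 class definitions
scaling_modules = {  # m = 2 scaling callables
    "pi_x2": linear_scaling(2.0),
    "pi_x4": linear_scaling(4.0),
}

# Every (variant, scaling) pair is built from the n + m definitions above.
combos = {
    (cls.__name__, name): cls([1.0, 0.1], scaling_module=fn)
    for cls, (name, fn) in product(base_variants, scaling_modules.items())
}
```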

RdoubleA commented 1 month ago

This seems more similar to the original and current HF rope implementations which have separate functions for constructing the theta/freq matrix and for applying it during the attention calculation.

First, I want to confirm that the above implementation is similar to what you had in mind?

I think having the two completely separate does make it confusing for a user who's not an expert in positional embeddings and the difference between normal RoPE and scaled RoPE is not entirely clear, but I do agree it is slightly more future-proof.

This theta function approach is interesting but I don't have a good overview of all the possible methods for modifying/scaling theta. Do you anticipate or know of a method that wouldn't fit as a "scaling" method?

I see the added complexity of a scaling module or a theta function as quite similar in terms of UX, and the latter gives a bit more flexibility. Would there be a default theta function for non-scaled RoPE?

If you have reasonable confidence that model-specific RoPE implementations will only impact scaling, then your original approach makes sense. If there's some possibility that models may need full control over theta, then I would vote for the theta function, since it does not worsen UX compared to the original imo.
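One possible shape for the theta-function option with a default for non-scaled RoPE, written as a dependency-free sketch. All names are hypothetical, plain lists stand in for tensors, and the forward pass is omitted.

```python
from typing import Optional

Theta = list[float]


def default_theta(dim: int, base: float = 10_000.0) -> Theta:
    """Unscaled RoPE frequencies, used when no theta is supplied."""
    return [base ** (-2 * i / dim) for i in range(dim // 2)]


def linear_scaled_theta(dim: int, scaling_factor: float, base: float = 10_000.0) -> Theta:
    """Example of a model-specific builder with full control over theta."""
    return [t / scaling_factor for t in default_theta(dim, base)]


class RotaryPositionalEmbeddingsSketch:
    """Takes a prebuilt theta; falls back to the unscaled default."""

    def __init__(self, dim: int, theta: Optional[Theta] = None) -> None:
        self.dim = dim
        self.theta = default_theta(dim) if theta is None else list(theta)
        if len(self.theta) != dim // 2:
            raise ValueError("theta must have dim // 2 entries")


plain_rope = RotaryPositionalEmbeddingsSketch(dim=8)
scaled_rope = RotaryPositionalEmbeddingsSketch(
    dim=8, theta=linear_scaled_theta(8, scaling_factor=4.0)
)
```

Under this arrangement a builder for a non-scaled model would not have to mention theta at all, while a model that needs full control can pass whatever its theta function produces.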