pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

[Feature Request] Lazy upcasting for mmap'd state dicts #130480

Open muellerzr opened 1 month ago

muellerzr commented 1 month ago

In transformers, as a rule, we always load models in float32 for stability, even if the weights are stored in bfloat16. As a result, loading llama-3-8B can't be done lazily via mmap: we have to upcast every value in the state_dict immediately, which leads to some pretty slow load times.

It'd be nice if there were an API that could take a lazy-loaded state_dict and hook into it so that, when we request something from that state dict (before it's read from disk, etc.), the parameter is automatically converted with .to(float32) (or whatever precision you may want).
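A minimal sketch of the requested behaviour, assuming a plain read-only mapping is enough: LazyUpcastStateDict below is a hypothetical wrapper (not a PyTorch API) that casts floating-point tensors to the target dtype only when an entry is looked up. With a state_dict from torch.load(..., mmap=True), that lookup is roughly when the bytes get paged in from disk.

```python
from collections.abc import Mapping

import torch


class LazyUpcastStateDict(Mapping):
    """Hypothetical read-only view over a state_dict (not a PyTorch API).

    Floating-point tensors are cast to `dtype` only when an entry is accessed;
    the wrapped state_dict itself is never modified.
    """

    def __init__(self, state_dict, dtype=torch.float32):
        self._sd = state_dict
        self._dtype = dtype

    def __getitem__(self, key):
        value = self._sd[key]
        # For an mmap'd tensor, this .to() is where its pages are actually read.
        if isinstance(value, torch.Tensor) and value.is_floating_point():
            return value.to(self._dtype)
        # Integer tensors and non-tensor entries are passed through unchanged.
        return value

    def __iter__(self):
        return iter(self._sd)

    def __len__(self):
        return len(self._sd)


sd = {"weight": torch.ones(2, dtype=torch.bfloat16), "step": torch.tensor(3)}
lazy = LazyUpcastStateDict(sd)
print(lazy["weight"].dtype, lazy["step"].dtype)
```

Note that nn.Module.load_state_dict reads every matching entry while loading, so passing this wrapper to it still upcasts everything at load time; on its own it only helps code that pulls entries out one at a time.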

transformers PR where this behavior would be very useful: https://github.com/huggingface/transformers/pull/31771#discussion_r1672525371

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @msaroufim

muellerzr commented 1 month ago

Or, if something like this already exists in PyTorch core (I don't know), a hook that runs when a model weight is loaded via mmap so we can perform some operation on it?

msaroufim commented 1 month ago

Actually, thinking about this more, maybe we should discuss this on pytorch/pytorch, since @mikaylagawarecki and @albanD can give a better answer than me.

mikaylagawarecki commented 1 month ago

This seems like a potential use case for torch.Tensor.module_load, which can hook into the nn.Module.load_state_dict call.

If the model definition contains a __torch_function__ tensor subclass for all parameters, and the state_dict was loaded via mmap, YourTensorSubclass.module_load could be defined via __torch_function__ to do the transformation from bfloat16 to float32 when each individual parameter is being loaded.

Will look into this further!

muellerzr commented 1 month ago

That sounds perfect @mikaylagawarecki, looking forward to it!

muellerzr commented 1 month ago

My other fear with this is how easy it is to scale (e.g. can it be generic enough to apply via inheritance, so that all transformers models can make use of this functionality if configured?).

A toy example to see how __torch_function__ would work in this case would be exceedingly helpful :)

mikaylagawarecki commented 1 month ago

Hey @muellerzr, does this snippet achieve what you are looking for?

```python
import torch
import torch.nn as nn
from torch.overrides import TorchFunctionMode

model = nn.Transformer  # the model class; instantiated below

# Assume we had saved the state_dict like so:
# sd = model(dtype=torch.bfloat16).state_dict()
# torch.save(sd, "sd_bf16.pt")

class UpcastMode(TorchFunctionMode):
    '''
    __torch_function__ mode that overrides Tensor.module_load behavior to always
    upcast each value in the state_dict to the dtype of the matching param in the model.
    '''
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func == torch.Tensor.module_load:
            # args[0] is the param in the model (dest), args[1] the state_dict value (src)
            dest, src = args[0], args[1]
            return src.to(dest.dtype).detach()
        return func(*args, **kwargs)

sd_loaded = torch.load("sd_bf16.pt", mmap=True, weights_only=True)
with UpcastMode(), torch.device("meta"):
    # Params are created on the meta device, so no real memory is allocated here.
    m = model(dtype=torch.float32)
    # For now, need to set this __future__ for `module_load` to be called.
    torch.__future__.set_swap_module_params_on_conversion(True)
    try:
        m.load_state_dict(sd_loaded, assign=True)
    finally:
        # Restore the global flag even if loading fails.
        torch.__future__.set_swap_module_params_on_conversion(False)
```

muellerzr commented 1 month ago

@mikaylagawarecki close, but not quite! We still wind up loading the state dict in there, rather than delaying it as long as we can (i.e. until an input goes through the model). I'm not sure whether what I want there is actually possible.

Since this isn't really different from just going through and calling .to(dest.dtype) manually after loading the state dict ourselves, no?
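If the goal is to defer the cast until an input actually reaches the model, one option outside load_state_dict is a forward pre-hook. This is an assumption-laden illustration, not an existing PyTorch or transformers API: the hypothetical helper upcast_on_first_forward registers a pre-hook (via nn.Module.register_forward_pre_hook) on every submodule that owns parameters and casts those parameters right before the submodule first runs. If the bf16 parameters had been assigned from an mmap'd state_dict (e.g. load_state_dict(..., assign=True)), the .to() call in the hook is where their pages would be read.

```python
import torch
import torch.nn as nn


def upcast_on_first_forward(model: nn.Module, dtype: torch.dtype = torch.float32):
    """Hypothetical helper (not a PyTorch API).

    Registers a forward pre-hook on every submodule that owns parameters, so its
    floating-point parameters are cast to `dtype` just before that submodule runs.
    Returns the hook handles so they can be removed later.
    """

    def hook(module, args):
        with torch.no_grad():
            for param in module.parameters(recurse=False):
                if param.is_floating_point() and param.dtype != dtype:
                    # Replacing .data keeps the same nn.Parameter object, as
                    # Module.to() does by default.
                    param.data = param.data.to(dtype)

    handles = []
    for module in model.modules():
        # Only hook modules that directly own parameters (e.g. Linear, not Sequential).
        if next(module.parameters(recurse=False), None) is not None:
            handles.append(module.register_forward_pre_hook(hook))
    return handles


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)).to(torch.bfloat16)
handles = upcast_on_first_forward(model)
print({p.dtype for p in model.parameters()})  # still bfloat16 before any input
out = model(torch.randn(3, 4))
print({p.dtype for p in model.parameters()}, out.dtype)
```

After the first call the hook becomes a cheap no-op; the returned handles can be .remove()d once the model has run once.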

mikaylagawarecki commented 1 month ago

Oh, I see what you mean, hmmm

I am slightly surprised that, with lazy loading, the drop in first-pass throughput (when the first input is passed to the model) is so small compared to the model init time without it :o

Curious: for the first-pass throughput numbers in https://github.com/huggingface/transformers/pull/31771#issue-2388519274, is the page cache cleared each time before mmap-ing the checkpoint?

muellerzr commented 1 month ago

Probably not; I'm not sure how to do that, so some advice would be nice! :) However, the hardware I'm working with (M.2 drives) can read up to 14.5GB/s, so it's not unreasonable that llama-8B's data is read so fast it seems almost instantaneous.

mikaylagawarecki commented 1 month ago

I'm not sure if there's a way to do it from Python, but on Linux I think it would be sudo sysctl vm.drop_caches=1
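From Python, one per-file alternative is os.posix_fadvise with POSIX_FADV_DONTNEED, which asks the kernel to drop a single file's cached pages without needing root. It is only advisory and only exists on platforms that provide posix_fadvise (e.g. Linux); pages that are still dirty may be kept, so for a freshly written checkpoint running sync first helps. The helper name evict_from_page_cache is hypothetical.

```python
import os


def evict_from_page_cache(path: str) -> None:
    """Hypothetical helper: advise the kernel to drop cached pages of one file.

    Unlike `sysctl vm.drop_caches`, this needs no root and only targets `path`.
    """
    if not hasattr(os, "posix_fadvise"):
        raise OSError("os.posix_fadvise is not available on this platform")
    fd = os.open(path, os.O_RDONLY)
    try:
        # offset=0, len=0 means "from the start to the end of the file".
        os.posix_fadvise(fd, 0, 0, os.POSIX_FADV_DONTNEED)
    finally:
        os.close(fd)
```

For example, calling evict_from_page_cache("sd_bf16.pt") before each timed run would make the next mmap read the checkpoint from disk again (as far as the kernel honours the advice).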

muellerzr commented 1 month ago

@mikaylagawarecki I dropped the page cache between runs.

New timings (llama-3-8B in bfloat16 on CPU)

Model init:

W/o lazy loading: 3.025 seconds
W/ lazy loading: 0.319 seconds

First pass:

W/o lazy loading: 2.353 tok/s (total time == 8.499s)
W/ lazy loading: 2.020 tok/s (total time == 9.903s)

At ~16 GB for the weights, the extra ~1.4s on the first pass (9.903s vs 8.499s) works out to ~11.4 GB/s, which makes sense for the speed of my M.2 drive.
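The ~11.4 GB/s estimate can be reproduced from the numbers above, assuming the whole first-pass slowdown with lazy loading is time spent reading the ~16 GB of weights from disk:

```python
# First-pass figures reported above (llama-3-8B, bfloat16, CPU, page cache dropped).
weights_gb = 16.0           # approximate size of the weights
first_pass_eager_s = 8.499  # w/o lazy loading
first_pass_lazy_s = 9.903   # w/ lazy loading

# Assumption: the extra time with lazy loading is all spent reading weights.
extra_s = first_pass_lazy_s - first_pass_eager_s
read_gbps = weights_gb / extra_s

print(f"extra first-pass time: {extra_s:.3f} s")  # extra first-pass time: 1.404 s
print(f"implied read speed: {read_gbps:.1f} GB/s")  # implied read speed: 11.4 GB/s
```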

Second pass:

W/o lazy loading: 2.444 tok/s (total time == 8.182s)
W/ lazy loading: 2.434 tok/s (total time == 8.218s)