dvmazur / mixtral-offloading

Run Mixtral-8x7B models in Colab or consumer desktops
MIT License

Run without quantization #22

Open freQuensy23-coder opened 10 months ago

freQuensy23-coder commented 10 months ago

QuantConfig is a mandatory argument of the build_model function:


model = build_model(
    device=device,
    quant_config=quant_config,
    offload_config=offload_config,
    state_path=state_path,
)

Can I run Mixtral with layer offloading, but WITHOUT quantization, using this library?

dvmazur commented 10 months ago

What hardware do you plan to run the model on? Running it without quantization would require quite a large amount of combined RAM + VRAM.

freQuensy23-coder commented 10 months ago

I'll use a Tesla A100 with 80 GB of VRAM + 512 GB of RAM.

dvmazur commented 10 months ago

Yeah, sounds like it'll fit :D

The current codebase doesn't support running the model without quantization, but you could try rewriting the expert wrapper class.

This class moves the expert's parameters into a single storage, so that they can later be moved efficiently between GPU and CPU memory. Here's a snippet that does this for the original expert class:

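import torch

# nested_flatten and nested_pack are assumed to be the helpers of the same
# name from this repository: they flatten a nested structure into a sequence
# and pack a flat list back into the original structure.


def replace_layer_storage(layer, device):
    state_dict = layer.state_dict()

    # First pass: compute the byte offset of every tensor in the shared storage.
    storage_size = 0
    offsets = [0]

    for x in nested_flatten(state_dict):
        if not isinstance(x, torch.Tensor):
            continue
        storage_size += x.nbytes
        offsets.append(storage_size)

    # One contiguous allocation that holds all of the layer's tensors.
    storage = torch.UntypedStorage(storage_size, device=device)

    # Second pass: copy each tensor into its own slice of the storage.
    i = 0
    new_flattened_states = list()
    for x in nested_flatten(state_dict):
        if not isinstance(x, torch.Tensor):
            new_flattened_states.append(x)
            continue

        start = offsets[i]
        end = offsets[i + 1]
        a_view = torch.as_tensor(storage[start:end], dtype=x.dtype, device=device).view(x.shape)
        a_view[...] = x
        # The view must start exactly at its offset inside the shared storage.
        assert a_view.data_ptr() == storage.data_ptr() + start
        i += 1
        new_flattened_states.append(a_view)

    state_dict = nested_pack(new_flattened_states, state_dict)

    # Point the layer's parameters at the views backed by the shared storage.
    for name, param in layer.named_parameters():
        param.data = state_dict[name]

    return layer, storage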

The rest of the codebase is still quite HQQ-specific, and offloading the unquantized model will require rewriting some code in the build_model.py file. Most of it boils down to replacing HQQ layers with default PyTorch ones, though.

If you decide to go down that path, I can help you out a bit in this issue :)

lavawolfiee commented 10 months ago

Seems like you'll be a little bit short on VRAM. Full fp16 model requires ~87GB. The table is taken from our tech report.

[image: table from the tech report]
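The ~87 GB figure can be sanity-checked with simple arithmetic. The sketch below assumes a total of about 46.7 billion parameters for Mixtral-8x7B (a publicly reported number, not one given in this thread) and counts only the weights at 2 bytes each; the function name is made up for illustration.

```python
# Assumption: Mixtral-8x7B has roughly 46.7 billion parameters in total
# (publicly reported figure, not stated in this thread).
TOTAL_PARAMS = 46.7e9
BYTES_PER_PARAM_FP16 = 2  # fp16 stores each parameter in 2 bytes


def fp16_footprint_gib(num_params: float) -> float:
    """Return the size of the weights alone, in GiB (2**30 bytes)."""
    return num_params * BYTES_PER_PARAM_FP16 / 2**30


full_model_gib = fp16_footprint_gib(TOTAL_PARAMS)
print(f"fp16 weights: {full_model_gib:.1f} GiB")
```

This covers weights only; activations and the KV cache need additional memory on top of it.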

freQuensy23-coder commented 10 months ago

> Seems like you'll be a little bit short on VRAM. Full fp16 model requires ~87GB. The table is taken from our tech report.

I'll offload some of the experts to RAM during inference, so it will use less GPU VRAM. That's the main idea of this lib. @dvmazur, am I right?
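To get a rough feel for how many experts would have to live in RAM, here is a back-of-the-envelope sketch. It assumes the public Mixtral-8x7B shape (hidden size 4096, intermediate size 14336, 32 layers, 8 experts per layer, three weight matrices per expert), none of which is stated in this thread. The `reserved_gib` value is an illustrative guess for non-expert weights, activations and KV cache, and `split_experts` is a hypothetical helper, not part of this library.

```python
from typing import Tuple

# Assumed Mixtral-8x7B shape (from the public model config, not this thread).
HIDDEN_SIZE = 4096
INTERMEDIATE_SIZE = 14336
NUM_LAYERS = 32
EXPERTS_PER_LAYER = 8
BYTES_FP16 = 2
GIB = 2**30


def expert_bytes() -> int:
    # Each expert is a gated MLP with three weight matrices of equal size.
    return 3 * HIDDEN_SIZE * INTERMEDIATE_SIZE * BYTES_FP16


def split_experts(vram_gib: float, reserved_gib: float) -> Tuple[int, int]:
    """Return (experts that fit on the GPU, experts offloaded to RAM).

    reserved_gib is VRAM set aside for everything that is not an expert.
    """
    total = NUM_LAYERS * EXPERTS_PER_LAYER
    budget = max(0.0, vram_gib - reserved_gib) * GIB
    on_gpu = min(total, int(budget // expert_bytes()))
    return on_gpu, total - on_gpu


# Illustrative guess: 11 GiB reserved out of 80 GiB of VRAM.
on_gpu, offloaded = split_experts(vram_gib=80, reserved_gib=11)
print(f"{on_gpu} experts on GPU, {offloaded} offloaded to RAM")
```

This is a static budget only. How the library actually decides which experts stay on the GPU at any moment is not covered by this estimate.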

freQuensy23-coder commented 10 months ago

> If you decide to go down that path, I can help you out a bit in this issue :)

Thanks, I’d appreciate your help with this. I'll also try to do it myself this evening.

dvmazur commented 10 months ago

@freQuensy23-coder, yes, you are right - @lavawolfiee must have misunderstood you.

freQuensy23-coder commented 10 months ago

I've tried to rewrite your code to add fp16 support using your tips, but I ran into some difficulties: I don't understand where exactly replace_layer_storage uses quantization. As far as I can tell, it will work with 16-bit layers too? Can you help me with this?

dvmazur commented 10 months ago

> I've tried to rewrite your code to add fp16 support using your tips, but I ran into some difficulties: I don't understand where exactly replace_layer_storage uses quantization. As far as I can tell, it will work with 16-bit layers too? Can you help me with this?

The snippet I sent you doesn't use quantization. It simply moves a given layer's parameters into one single storage.
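As a quick illustration that nothing here depends on quantization, below is a simplified, self-contained sketch of the same idea applied to a plain fp16 `nn.Linear`. It is not the library's code: it uses a flat `uint8` tensor in place of `torch.UntypedStorage`, only handles parameters (not buffers or nested state), and `pack_params_into_one_buffer` is a hypothetical name.

```python
import torch
from torch import nn


def pack_params_into_one_buffer(layer: nn.Module, device: str = "cpu"):
    """Copy all parameters of `layer` into one flat byte buffer.

    Simplified sketch: assumes the parameters share one dtype, so every
    byte offset is aligned to the element size.
    """
    params = list(layer.named_parameters())

    # Byte offset at which each parameter starts inside the buffer.
    offsets = [0]
    for _, p in params:
        offsets.append(offsets[-1] + p.numel() * p.element_size())

    # One contiguous allocation for the whole layer.
    buffer = torch.empty(offsets[-1], dtype=torch.uint8, device=device)

    for (_, p), start, end in zip(params, offsets[:-1], offsets[1:]):
        # Reinterpret the byte slice as the parameter's dtype and shape.
        view = buffer[start:end].view(p.dtype).view(p.shape)
        view.copy_(p.data)
        p.data = view  # the parameter now lives inside the shared buffer

    return layer, buffer


# Usage: an ordinary, unquantized fp16 layer.
layer = nn.Linear(8, 4).half()
before = {k: v.clone() for k, v in layer.state_dict().items()}
layer, buf = pack_params_into_one_buffer(layer)
```

Because the parameters are views into `buf`, moving the layer between devices comes down to copying one contiguous block instead of many small tensors.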